This series is written as a quick introduction to software design for data scientists: something more lightweight than the design pattern bible, Clean Code, and the kind of guide I wish had existed when I first started learning. Design patterns are reusable solutions to common problems, and some of them happen to be useful for data science. There is a good chance that someone else has already solved your problem before. When used wisely, they help reduce the complexity of your code.
So, what is a Callback after all?
A callback function, or "call after", simply means a function that will be called after another function. It is a piece of executable code (a function) that is passed as an argument to another function. [1]
def foo(x, callback=None):
    print('foo!')
    if callback:
        callback(x)

foo('123')
foo!

foo('123', print)
foo!
123
Here I pass the function print as a callback, hence the string 123 gets printed after foo!.
Why do I need to use Callbacks?
Callbacks are very common in high-level deep learning libraries; most likely you will find them in the training loop (a rough Keras sketch follows the list below).
* fastai - fastai provides a high-level API for PyTorch
* Keras - the high-level API for TensorFlow
* ignite - uses events & handlers, which provides more flexibility in their opinion
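As a taste of how this looks in practice, here is a minimal sketch using Keras. Only tf.keras.callbacks.Callback, its on_epoch_end hook, and the callbacks argument of fit are the real API; the PrintLossCallback name and the toy model and data are made up purely for illustration.

import numpy as np
import tensorflow as tf

class PrintLossCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Keras calls this hook for us at the end of every epoch
        print(f"End of Epoch. Epoch: {epoch}, Loss: {logs['loss']:.4f}")

# Toy model and data, just to show where the callback is plugged in
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss='mse')
model.fit(np.ones((32, 10)), np.ones((32, 1)), epochs=3, verbose=0,
          callbacks=[PrintLossCallback()])

The training loop itself stays untouched; the library calls your hook at the right moment.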
import numpy as np
# A boring training loop
def train(x):
    n_epochs = 3
    n_batches = 2
    loss = 20
    for epoch in range(n_epochs):
        for batch in range(n_batches):
            loss = loss - 1  # Pretend we are training the model

x = np.ones(10); train(x)
So, let’s say you now want to print the loss at the end of each epoch. You can just add one line of code.
The simple approach
def train_with_print(x):
    n_epochs = 3
    n_batches = 2
    loss = 20
    for epoch in range(n_epochs):
        for batch in range(n_batches):
            loss = loss - 1  # Pretend we are training the model
        print(f'End of Epoch. Epoch: {epoch}, Loss: {loss}')
    return loss

train_with_print(x);
End of Epoch. Epoch: 0, Loss: 18
End of Epoch. Epoch: 1, Loss: 16
End of Epoch. Epoch: 2, Loss: 14
Callback approach
Or you can add a PrintCallback, which does the same thing but with a bit more code.
class Callback:
    def on_epoch_start(self, x):
        pass

    def on_epoch_end(self, x):
        pass

    def on_batch_start(self, x):
        pass

    def on_batch_end(self, x):
        pass

class PrintCallback(Callback):
    def on_epoch_end(self, x):
        print(f'End of Epoch. Loss: {x}')
def train_with_callback(x, callback=None):
    n_epochs = 3
    n_batches = 2
    loss = 20
    for epoch in range(n_epochs):
        callback.on_epoch_start(loss)
        for batch in range(n_batches):
            callback.on_batch_start(loss)
            loss = loss - 1  # Pretend we are training the model
            callback.on_batch_end(loss)
        callback.on_epoch_end(loss)

train_with_callback(x, callback=PrintCallback());
End of Epoch. Loss: 18
End of Epoch. Loss: 16
End of Epoch. Loss: 14
Usually, a callback defines a few particular events, on_xxx_xxx, which indicate that the function will be executed when the corresponding condition occurs. All callbacks inherit from the base class Callback and override the desired methods; here we only implement the on_epoch_end method because we only want to show the loss at the end of each epoch.
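To show another event being overridden, here is a hypothetical LossHistory callback built on the same Callback base class; it records the loss after every batch instead of printing at the end of an epoch.

class LossHistory(Callback):
    def __init__(self):
        self.losses = []

    def on_batch_end(self, x):
        # Record the loss after every batch
        self.losses.append(x)

Passing history = LossHistory() into train_with_callback(x, callback=history) would leave the recorded losses in history.losses.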
It may seem awkward to write so much more code to do one simple thing, but there are good reasons. Consider that you now need to add more features; how would you do it?
- ModelCheckpoint
- Early Stopping
- LearningRateScheduler
You can just add code to the loop, but it will keep growing into a giant function, as sketched below. It is impossible to test such a function because it does ten things at the same time. In addition, the extra code may not even be related to the training logic; it is only there to save the model or plot a chart. So it is best to separate the logic: a function should do only one thing, according to the Single Responsibility Principle. Callbacks help you reduce complexity because they provide a nice abstraction; you only modify code within the specific callback you are interested in.
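To make the contrast concrete, here is a sketch of what the loop tends to become once everything is inlined. The thresholds, names, and behaviour are made up purely for illustration.

def train_doing_everything(x):
    n_epochs = 3
    n_batches = 2
    loss = 20
    best_loss = float('inf')
    for epoch in range(n_epochs):
        for batch in range(n_batches):
            loss = loss - 1                # training logic
            print('Reduce learning rate')  # scheduling logic
        print(f'End of Epoch. Epoch: {epoch}, Loss: {loss}')  # logging logic
        if loss < best_loss:               # checkpointing logic
            best_loss = loss
            print('Save Model')
        if loss < 16:                      # early-stopping logic
            print('Early Stopped')
            break

Every new feature adds another branch to the same function, and testing any single concern means running all of them.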
Add some more sauce!
When using the Callback pattern, I can just implement a few more classes and the training loop is barely touched. Here we introduce a new class, Callbacks, because we need to execute more than one callback; it holds all the callbacks and executes them sequentially.
class Callbacks:
    """
    It is the container for callbacks
    """
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def on_epoch_start(self, x):
        for callback in self.callbacks:
            callback.on_epoch_start(x)

    def on_epoch_end(self, x):
        for callback in self.callbacks:
            callback.on_epoch_end(x)

    def on_batch_start(self, x):
        for callback in self.callbacks:
            callback.on_batch_start(x)

    def on_batch_end(self, x):
        for callback in self.callbacks:
            callback.on_batch_end(x)
Then we implement the new callbacks one by one. Here we only have pseudocode, but you should get the gist. For example, we only need to save the model at the end of an epoch, so the ModelCheckPoint callback implements the on_epoch_end method.
class PrintCallback(Callback):
    def on_epoch_end(self, x):
        print(f'[{type(self).__name__}]: End of Epoch. Loss: {x}')

class ModelCheckPoint(Callback):
    def on_epoch_end(self, x):
        print(f'[{type(self).__name__}]: Save Model')

class EarlyStoppingCallback(Callback):
    def on_epoch_end(self, x):
        if x < 16:
            print(f'[{type(self).__name__}]: Early Stopped')

class LearningRateScheduler(Callback):
    def on_batch_end(self, x):
        print(f'[{type(self).__name__}]: Reduce learning rate')
We also modify the training loop a bit; the argument now takes a Callbacks object, which contains zero to many callbacks.
def train_with_callbacks(x, callbacks=None):
    n_epochs = 2
    n_batches = 3
    loss = 20
    for epoch in range(n_epochs):
        # on_epoch_start
        callbacks.on_epoch_start(loss)
        for batch in range(n_batches):
            # on_batch_start
            callbacks.on_batch_start(loss)
            loss = loss - 1  # Pretend we are training the model
            # on_batch_end
            callbacks.on_batch_end(loss)
        # on_epoch_end
        callbacks.on_epoch_end(loss)

callbacks = Callbacks([PrintCallback(), ModelCheckPoint(),
                       EarlyStoppingCallback(), LearningRateScheduler()])
train_with_callbacks(x, callbacks=callbacks)
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[PrintCallback]: End of Epoch. Loss: 17
[ModelCheckPoint]: Save Model
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[LearningRateScheduler]: Reduce learning rate
[PrintCallback]: End of Epoch. Loss: 14
[ModelCheckPoint]: Save Model
[EarlyStoppingCallback]: Early Stopped
Hopefully, this convinces you that callbacks make the code cleaner and easier to maintain. If you just use plain if-else statements, you may end up with a big chunk of if-else clauses.
Reference
[1] https://stackoverflow.com/questions/824234/what-is-a-callback-function