5 Minutes Data Science Design Patterns I - Callback

python
design pattern
software
A mini collection of design patterns for data science, starting with callbacks.
Author

noklam

Published

July 10, 2021

Note

This series is the quick, lightweight introduction to software design for data scientists that I wish had existed when I first started learning: something lighter than a design-pattern bible like Clean Code. Design patterns are reusable solutions to common problems, and some of them happen to be useful for data science. There is a good chance that someone else has already solved your problem. Used wisely, design patterns help reduce the complexity of your code.

So, what is a callback after all?

A callback function (think of it as "call after") is simply a function that gets called after another function. It is a piece of executable code (a function) that is passed as an argument to another function, which then decides when to call it. [1]

def foo(x, callback=None):
    print('foo!')
    if callback:
        callback(x)
foo('123')
foo!
foo('123', print)
foo!
123

Here I pass the built-in function print as the callback, so the string 123 gets printed after foo!.
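Because Python functions are first-class objects, anything callable can serve as the callback, not only print. A small sketch (the list received and the lambda are purely illustrative, not from the original post):

```python
def foo(x, callback=None):
    print('foo!')
    if callback:
        callback(x)


received = []

# A bound method is callable, so it works as a callback.
foo('123', received.append)

# So does a lambda that transforms the value before storing it.
foo('abc', lambda value: received.append(value.upper()))

print(received)
```

foo itself never changes; only the behaviour plugged into it does.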

Why do I need callbacks?

Callbacks are very common in high-level deep learning libraries, where you will most likely find them in the training loop:

  • fastai - provides a high-level API for PyTorch
  • Keras - the high-level API for TensorFlow
  • ignite - uses events and handlers instead, which in its authors' view provides more flexibility

import numpy as np

# A boring training loop
def train(x):
    n_epochs = 3
    n_batches = 2
    loss = 20

    for epoch in range(n_epochs):
        for batch in range(n_batches):
            loss = loss - 1  # Pretend we are training the model
x = np.ones(10)
train(x);

Now, let’s say you want to print the loss at the end of each epoch. You can just add one line of code.

The simple approach

def train_with_print(x):
    n_epochs = 3
    n_batches = 2
    loss = 20

    for epoch in range(n_epochs):
        for batch in range(n_batches):
            loss = loss - 1 # Pretend we are training the model
        print(f'End of Epoch. Epoch: {epoch}, Loss: {loss}')
    return loss
train_with_print(x);
End of Epoch. Epoch: 0, Loss: 18
End of Epoch. Epoch: 1, Loss: 16
End of Epoch. Epoch: 2, Loss: 14

Callback approach

Or you can add a PrintCallback, which does the same thing with a bit more code.

class Callback:
    def on_epoch_start(self, x):
        pass

    def on_epoch_end(self, x):
        pass

    def on_batch_start(self, x):
        pass

    def on_batch_end(self, x):
        pass


class PrintCallback(Callback):
    def on_epoch_end(self, x):
        print(f'End of Epoch. Loss: {x}')


def train_with_callback(x, callback=None):
    if callback is None:
        callback = Callback()  # the base class does nothing, so it is a safe default
    n_epochs = 3
    n_batches = 2
    loss = 20

    for epoch in range(n_epochs):

        callback.on_epoch_start(loss)

        for batch in range(n_batches):
            callback.on_batch_start(loss)
            loss = loss - 1  # Pretend we are training the model
            callback.on_batch_end(loss)

        callback.on_epoch_end(loss)
train_with_callback(x, callback=PrintCallback());
End of Epoch. Loss: 18
End of Epoch. Loss: 16
End of Epoch. Loss: 14

Usually, a callback defines a few specific events, named on_xxx_xxx, each indicating when the corresponding method will be executed. All callbacks inherit from the base class Callback and override only the methods they need. Here we only implement on_epoch_end because we only want to show the loss at the end of each epoch.
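To see that a callback only overrides the hooks it cares about, here is a hypothetical LossHistory callback (not from the original post) that implements only on_batch_end and records the loss after every batch. The training loop is repeated from above so the block runs on its own; its x argument is unused, as in the post, so None is passed.

```python
class Callback:
    """Base class: every hook is a no-op by default."""
    def on_epoch_start(self, x): pass
    def on_epoch_end(self, x): pass
    def on_batch_start(self, x): pass
    def on_batch_end(self, x): pass


class LossHistory(Callback):
    """Hypothetical callback: records the loss after each batch."""
    def __init__(self):
        self.losses = []

    def on_batch_end(self, x):
        self.losses.append(x)


def train_with_callback(x, callback=None):
    if callback is None:
        callback = Callback()
    n_epochs = 3
    n_batches = 2
    loss = 20
    for epoch in range(n_epochs):
        callback.on_epoch_start(loss)
        for batch in range(n_batches):
            callback.on_batch_start(loss)
            loss = loss - 1  # Pretend we are training the model
            callback.on_batch_end(loss)
        callback.on_epoch_end(loss)
    return loss


history = LossHistory()
train_with_callback(None, callback=history)
print(history.losses)
```

The other three hooks fall back to the no-op versions inherited from Callback.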

It may seem awkward to write so much more code to do one simple thing, but there are good reasons. Suppose you now need to add more features. How would you do it?

  • ModelCheckpoint
  • Early Stopping
  • LearningRateScheduler

You could just add the code to the loop, but it would soon grow into a giant function. Such a function is very hard to test because it does ten things at once. In addition, much of the extra code is not even related to the training logic; it is only there to save the model or plot a chart. So it is best to separate this logic. According to the Single Responsibility Principle, a function should do only one thing. Callbacks reduce complexity by providing a clean abstraction: you only modify the code inside the specific callback you are interested in.
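Separating the logic also makes each piece testable on its own. As a sketch, the hypothetical StepDecay scheduler below (an illustration, not part of the original post) can be checked by calling its hook directly with a few fake loss values, without running a training loop. For simplicity it only stores the learning rate itself; wiring it to a real optimizer is left out and is an assumption of this example.

```python
class Callback:
    def on_epoch_start(self, x): pass
    def on_epoch_end(self, x): pass
    def on_batch_start(self, x): pass
    def on_batch_end(self, x): pass


class StepDecay(Callback):
    """Hypothetical scheduler: multiplies the learning rate by `factor` after each epoch."""
    def __init__(self, lr=0.1, factor=0.5):
        self.lr = lr
        self.factor = factor

    def on_epoch_end(self, x):
        self.lr *= self.factor


# Unit test: call the hook directly, no training loop or model required.
scheduler = StepDecay(lr=0.1, factor=0.5)
for fake_loss in [18, 16, 14]:
    scheduler.on_epoch_end(fake_loss)
assert abs(scheduler.lr - 0.0125) < 1e-12
print('StepDecay test passed')
```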

Add some more sauce!

With the Callback pattern, I can just implement a few more classes and barely touch the training loop. Here we introduce a new class, Callbacks, because we need to execute more than one callback; it holds all the callbacks and executes them sequentially.

class Callbacks:
    """
    It is the container for callback
    """

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def on_epoch_start(self, x):
        for callback in self.callbacks:
            callback.on_epoch_start(x)

    def on_epoch_end(self, x):
        for callback in self.callbacks:
            callback.on_epoch_end(x)

    def on_batch_start(self, x):
        for callback in self.callbacks:
            callback.on_batch_start(x)

    def on_batch_end(self, x):
        for callback in self.callbacks:
            callback.on_batch_end(x)
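The four methods of Callbacks differ only in the event name. One possible variation (a design alternative, not the post's code) looks the hook up by name with getattr, so adding a new event needs no new forwarding method. The method name call and the RecordingCallback class are hypothetical.

```python
class Callback:
    def on_epoch_start(self, x): pass
    def on_epoch_end(self, x): pass
    def on_batch_start(self, x): pass
    def on_batch_end(self, x): pass


class Callbacks:
    """Container that forwards an event, looked up by name, to every callback in order."""
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def call(self, event, x):
        for callback in self.callbacks:
            # getattr finds the hook by name, e.g. 'on_epoch_end';
            # a misspelled name raises AttributeError only at runtime.
            getattr(callback, event)(x)


class RecordingCallback(Callback):
    """Hypothetical callback that records the epoch-end losses it receives."""
    def __init__(self):
        self.seen = []

    def on_epoch_end(self, x):
        self.seen.append(x)


recorder = RecordingCallback()
callbacks = Callbacks([Callback(), recorder])
callbacks.call('on_batch_end', 19)  # both callbacks ignore this event
callbacks.call('on_epoch_end', 17)
print(recorder.seen)
```

The trade-off: explicit methods catch typos earlier and are easier to discover in an IDE, while name-based dispatch is shorter.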

Then we implement the new callbacks one by one. They are only pseudocode here (they print a message instead of doing the real work), but you should get the gist. For example, we only need to save the model at the end of an epoch, so the ModelCheckPoint callback implements only the on_epoch_end method.

class PrintCallback(Callback):
    def on_epoch_end(self, x):
        print(f'[{type(self).__name__}]: End of Epoch. Loss: {x}')


class ModelCheckPoint(Callback):
    def on_epoch_end(self, x):
        print(f'[{type(self).__name__}]: Save Model')


class EarlyStoppingCallback(Callback):
    def on_epoch_end(self, x):
        if x < 16:
            print(f'[{type(self).__name__}]: Early Stopped')


class LearningRateScheduler(Callback):
    def on_batch_end(self, x):
        print(f'    [{type(self).__name__}]: Reduce learning rate')
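To give a feel for what the ModelCheckPoint pseudocode might grow into, here is a sketch that only "saves" when the loss improves. The class name ImprovingCheckpoint and its saved_epochs list are assumptions for illustration; a real callback would serialize the model weights where this one just records the epoch number.

```python
class Callback:
    def on_epoch_start(self, x): pass
    def on_epoch_end(self, x): pass
    def on_batch_start(self, x): pass
    def on_batch_end(self, x): pass


class ImprovingCheckpoint(Callback):
    """Hypothetical checkpoint callback: saves only when the loss improves."""
    def __init__(self):
        self.best = float('inf')
        self.epoch = -1
        self.saved_epochs = []  # stand-in for checkpoint files written to disk

    def on_epoch_start(self, x):
        self.epoch += 1

    def on_epoch_end(self, x):
        if x < self.best:
            self.best = x
            # A real implementation would serialize the model weights here.
            self.saved_epochs.append(self.epoch)
            print(f'[{type(self).__name__}]: Save Model (epoch {self.epoch}, loss {x})')


checkpoint = ImprovingCheckpoint()
# Simulate three epochs whose losses are 17, 18 (worse) and 14.
for epoch_loss in [17, 18, 14]:
    checkpoint.on_epoch_start(epoch_loss)
    checkpoint.on_epoch_end(epoch_loss)
print(checkpoint.saved_epochs)
```

The epoch counter lives inside the callback because, in this design, hooks only receive the loss.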

We also modify the training loop slightly: the argument is now a Callbacks object, which may contain zero or more callbacks.

def train_with_callbacks(x, callbacks=None):
    if callbacks is None:
        callbacks = Callbacks([])  # an empty container: every event is a no-op
    n_epochs = 2
    n_batches = 3
    loss = 20

    for epoch in range(n_epochs):

        callbacks.on_epoch_start(loss)                             # on_epoch_start
        for batch in range(n_batches):
            callbacks.on_batch_start(loss)                         # on_batch_start
            loss = loss - 1  # Pretend we are training the model
            callbacks.on_batch_end(loss)                           # on_batch_end
        callbacks.on_epoch_end(loss)                               # on_epoch_end
callbacks = Callbacks([PrintCallback(), ModelCheckPoint(),
                      EarlyStoppingCallback(), LearningRateScheduler()])
train_with_callbacks(x, callbacks=callbacks)
    [LearningRateScheduler]: Reduce learning rate
    [LearningRateScheduler]: Reduce learning rate
    [LearningRateScheduler]: Reduce learning rate
[PrintCallback]: End of Epoch. Loss: 17
[ModelCheckPoint]: Save Model
    [LearningRateScheduler]: Reduce learning rate
    [LearningRateScheduler]: Reduce learning rate
    [LearningRateScheduler]: Reduce learning rate
[PrintCallback]: End of Epoch. Loss: 14
[ModelCheckPoint]: Save Model
[EarlyStoppingCallback]: Early Stopped
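The EarlyStoppingCallback above only prints a message; training carries on regardless. One way to make it actually stop (an assumption here, not the post's code) is to let the callback set a flag that the loop reads after each epoch. The stop_training attribute, the threshold parameter and train_with_early_stopping are hypothetical names.

```python
class Callback:
    def on_epoch_start(self, x): pass
    def on_epoch_end(self, x): pass
    def on_batch_start(self, x): pass
    def on_batch_end(self, x): pass


class EarlyStoppingCallback(Callback):
    """Hypothetical variant: requests a stop once the loss drops below `threshold`."""
    def __init__(self, threshold=16):
        self.threshold = threshold
        self.stop_training = False

    def on_epoch_end(self, x):
        if x < self.threshold:
            self.stop_training = True
            print(f'[{type(self).__name__}]: Early Stopped')


def train_with_early_stopping(callback, n_epochs=10, n_batches=3):
    loss = 20
    epochs_run = 0
    for epoch in range(n_epochs):
        callback.on_epoch_start(loss)
        for batch in range(n_batches):
            callback.on_batch_start(loss)
            loss = loss - 1  # Pretend we are training the model
            callback.on_batch_end(loss)
        callback.on_epoch_end(loss)
        epochs_run += 1
        # The loop only reads the flag; the stopping rule lives in the callback.
        if getattr(callback, 'stop_training', False):
            break
    return epochs_run, loss


print(train_with_early_stopping(EarlyStoppingCallback()))
```

With a threshold of 16, the loss is 17 after the first epoch and 14 after the second, so training stops after two of the ten epochs.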

Hopefully, this convinces you that callbacks make code cleaner and easier to maintain. If you just use plain if-else statements, you may end up with a big chunk of if-else clauses inside the training loop.
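For contrast, here is a rough sketch of the same features written as if-else flags directly inside the loop (the flag names are made up for illustration, and messages are collected in a list instead of printed). Every new feature adds another parameter to train and another branch inside it:

```python
def train(print_loss=False, save_model=False, early_stop=False, reduce_lr=False):
    n_epochs = 2
    n_batches = 3
    loss = 20
    log = []
    for epoch in range(n_epochs):
        for batch in range(n_batches):
            loss = loss - 1  # Pretend we are training the model
            if reduce_lr:
                log.append('Reduce learning rate')
        if print_loss:
            log.append(f'End of Epoch. Loss: {loss}')
        if save_model:
            log.append('Save Model')
        if early_stop and loss < 16:
            log.append('Early Stopped')
    return log


print(train(print_loss=True, save_model=True, early_stop=True))
```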


Reference

  1. https://stackoverflow.com/questions/824234/what-is-a-callback-function