How to write custom callbacks for AI training

Introduction

Callbacks are essentially functions that are executed at specific stages of a program's lifecycle. When training deep learning models, they can be attached to different stages of the training process: before training starts, after a batch of inputs is passed to the model, after the optimizer updates the model weights, and so on.
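
To make the idea concrete before looking at the full class, here is a minimal sketch of a loop that calls user-supplied functions at two fixed stages. The names run_steps, on_start and on_step_end are made up for this illustration and are not part of the yolov5 code.

```python
# Minimal sketch of the callback idea. `run_steps`, `on_start` and
# `on_step_end` are hypothetical names used only for this illustration.

def run_steps(n_steps, on_start=None, on_step_end=None):
    """Run a dummy loop and call the given callbacks at two stages."""
    if on_start is not None:
        on_start()
    for step in range(n_steps):
        result = step * step  # stand-in for real work such as a training step
        if on_step_end is not None:
            on_step_end(step, result)


events = []  # the callbacks below append to this list
run_steps(
    3,
    on_start=lambda: events.append("start"),
    on_step_end=lambda step, result: events.append((step, result)),
)
print(events)
```

The loop itself knows nothing about what the callbacks do. It only promises to call them at the agreed stages with the agreed arguments.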

This post demonstrates a CallbackRepo class that holds a Python dictionary as an attribute. The dictionary maps each stage of the program to a list of callback functions to execute at that stage. These stages are also referred to as hooks. The code in this article is taken and modified from the YOLOv5 GitHub repo.

CallbackRepo

The CallbackRepo essentially has four methods including its initializer. The code for the __init__ method is shown below:

class CallbackRepo:
    """
    Handles all registered callbacks for a deeplearning model
    """

    def __init__(self):
        # Define the available callbacks
        self._callbacks = {
            'on_pretrain_routine_start': [],
            'on_pretrain_routine_end': [],
            'on_train_start': [],
            'on_train_epoch_start': [],
            'on_train_batch_start': [],
            'optimizer_step': [],
            'on_before_zero_grad': [],
            'on_train_batch_end': [],
            'on_train_epoch_end': [],
            'on_val_start': [],
            'on_val_batch_start': [],
            'on_val_image_end': [],
            'on_val_batch_end': [],
            'on_val_end': [],
            'on_fit_epoch_end': [],  # fit = train + val
            'on_model_save': [],
            'on_train_end': [],
            'on_params_update': [],
            'teardown': [],}
        self.stop_training = False  # set True to interrupt training

The initializer sets two attributes, _callbacks and stop_training. The first is a dictionary that holds the callbacks for each hook, and the second is a flag that signals the training loop to stop.

After the initializer, the class has a method to register the callback functions to the respective hooks. The code for that is shown below.

...
    def register_action(self, hook, name='', callback=None):
        """
        Register a new action to a callback hook
        Args:
            hook: The callback hook name to register the action to
            name: The name of the action for later reference
            callback: The callback to fire
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

The register_action method takes in a hook, a name for the callback and a callback function. It first checks if the hook passed to the function is a valid hook. It then checks if the callback function passed is a callable. After that it appends the callback to the list in the self._callbacks dictionary for that particular hook.

To retrieve the callbacks assigned to a specific hook, the class implements a get_registered_actions method. Given a hook name, it returns the list of callbacks registered for that hook; called without a hook, it returns the whole dictionary. Note that the hook name is not validated here, so an unknown name raises a KeyError. The code for this function is shown below:

...
    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook
        Args:
            hook: The name of the hook to check, defaults to all
        """
        return self._callbacks[hook] if hook else self._callbacks

Finally, the class implements a __call__ method, which makes instances of the class callable. The code for this method is shown below:

...
    def __call__(self, hook, *args, **kwargs):
        """
        Loop through the registered actions and fire all callbacks
        Args:
            hook: The name of the hook whose callbacks are fired (required)
            args: Arguments to receive from a deeplearning model
            kwargs: Keyword Arguments to receive from a deeplearning model
        """

        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"

        for logger in self._callbacks[hook]:
            logger['callback'](*args, **kwargs)

The method takes in a hook and all the arguments and keyword arguments necessary for the callback functions of that hook. It first checks if the hook is valid and then calls each callback method one by one passing in the arguments and keyword arguments. The full code of this class is shown below:

#Taken and modified from: https://github.com/ultralytics/yolov5/blob/master/utils/callbacks.py
class CallbackRepo:
    """
    Handles all registered callbacks for a deeplearning model
    """

    def __init__(self):
        # Define the available callbacks
        self._callbacks = {
            'on_pretrain_routine_start': [],
            'on_pretrain_routine_end': [],
            'on_train_start': [],
            'on_train_epoch_start': [],
            'on_train_batch_start': [],
            'optimizer_step': [],
            'on_before_zero_grad': [],
            'on_train_batch_end': [],
            'on_train_epoch_end': [],
            'on_val_start': [],
            'on_val_batch_start': [],
            'on_val_image_end': [],
            'on_val_batch_end': [],
            'on_val_end': [],
            'on_fit_epoch_end': [],  # fit = train + val
            'on_model_save': [],
            'on_train_end': [],
            'on_params_update': [],
            'teardown': [],}
        self.stop_training = False  # set True to interrupt training

    def register_action(self, hook, name='', callback=None):
        """
        Register a new action to a callback hook
        Args:
            hook: The callback hook name to register the action to
            name: The name of the action for later reference
            callback: The callback to fire
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook
        Args:
            hook: The name of the hook to check, defaults to all
        """
        return self._callbacks[hook] if hook else self._callbacks

    def __call__(self, hook, *args, **kwargs):
        """
        Loop through the registered actions and fire all callbacks
        Args:
            hook: The name of the hook whose callbacks are fired (required)
            args: Arguments to receive from a deeplearning model
            kwargs: Keyword Arguments to receive from a deeplearning model
        """

        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"

        for logger in self._callbacks[hook]:
            logger['callback'](*args, **kwargs)
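
As a quick usage sketch, the snippet below registers two callbacks on one hook, fires them, and then inspects what is registered. To keep it runnable on its own, it uses a trimmed copy of CallbackRepo with only three hooks. The record_batch function and the 'on_lunch_break' hook name are invented for this example.

```python
class CallbackRepo:
    """Trimmed copy of the class above (fewer hooks) so this snippet runs standalone."""

    def __init__(self):
        self._callbacks = {'on_train_start': [], 'on_train_batch_end': [], 'on_train_end': []}
        self.stop_training = False

    def register_action(self, hook, name='', callback=None):
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def get_registered_actions(self, hook=None):
        return self._callbacks[hook] if hook else self._callbacks

    def __call__(self, hook, *args, **kwargs):
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        for logger in self._callbacks[hook]:
            logger['callback'](*args, **kwargs)


seen = []  # collects the arguments the hypothetical callback receives


def record_batch(epoch, batch_index):
    # Hypothetical callback: remember which batch just finished
    seen.append((epoch, batch_index))


callbacks = CallbackRepo()
callbacks.register_action('on_train_batch_end', name='record_batch', callback=record_batch)
callbacks.register_action(
    'on_train_batch_end',
    name='print_batch',
    callback=lambda epoch, batch_index: print(f"epoch {epoch}, batch {batch_index}"),
)

# Fires both callbacks, in the order they were registered
callbacks('on_train_batch_end', 0, 5)

# Look up the names registered on a single hook
names = [action['name'] for action in callbacks.get_registered_actions('on_train_batch_end')]

# A hook name that is not in the dictionary is rejected by the assertion
try:
    callbacks.register_action('on_lunch_break', callback=record_batch)
    rejected = False
except AssertionError:
    rejected = True
```

One thing to keep in mind: the validation relies on assert statements, which Python skips when it is run with the -O flag. If the checks matter in production, raising an explicit exception would be a safer choice.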

Loggers

This section presents a Loggers class that logs the training metrics, the model structure and sample images to TensorBoard. It also logs the metrics to a CSV file. The code is shown below:

import os
import warnings

import torch
import torchvision
from torch.utils.tensorboard.writer import SummaryWriter

import cv2

LOGGERS = ('csv', 'tb')
class Loggers():
    # YOLOv5 Loggers class
    def __init__(self, save_dir=None, logger=None, include=LOGGERS):
        self.save_dir = save_dir
        self.logger = logger  # for printing results to console
        self.include = include
        for k in LOGGERS:
            setattr(self, k, None)  # init empty logger dictionary
        self.csv = True  # always log to csv

        # TensorBoard
        s = self.save_dir
        if 'tb' in self.include:
            prefix = 'TensorBoard: '
            self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/")
            self.tb = SummaryWriter(str(s))


    def on_train_start(self):
        # Callback runs on train start
        pass

    def on_pretrain_routine_end(self):
        pass

    def on_train_batch_end(self, epoch, batch_index, model, imgs, plots=True):
        # Callback runs on train batch end
        if plots and self.tb:  # skip when TensorBoard logging is disabled
            if epoch == 0 and batch_index == 0:
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore')  # suppress jit trace warning
                    grid = torchvision.utils.make_grid(imgs)
                    self.tb.add_image('images', grid, 0)
                    self.tb.add_graph(torch.jit.trace(model, imgs[0:1], strict=False), []) #type: ignore

    def on_train_epoch_end(self, epoch):
        pass

    def on_val_image_end(self, pred, predn, path, names, im):
        pass

    def on_val_end(self):
        pass

    def on_fit_epoch_end(self, metrics, epoch):
        # Callback runs at the end of each fit (train+val) epoch
        if self.csv:
            file = self.save_dir / 'results.csv'
            n = len(metrics) + 1  # number of cols
            s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + list(metrics.keys()))).rstrip(',') + '\n')  # add header
            with open(file, 'a') as f:
                f.write(s + ('%20.5g,' * n % tuple([epoch] + list(metrics.values()))).rstrip(',') + '\n')

        if self.tb:
            for k, v in metrics.items():
                self.tb.add_scalar(k, v, epoch)

    def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
        pass

    def on_train_end(self):
        if self.tb:  # the writer only exists when 'tb' is in include
            self.tb.close()


    def on_params_update(self, params):
        pass

Note that many of the methods are left empty; functionality will be added to them in future blog posts. For now, the class does two things. At the end of the very first batch, the on_train_batch_end method logs the model graph and a grid of the input images to TensorBoard. At the end of every epoch, the on_fit_epoch_end method logs the metrics to a CSV file and to TensorBoard.
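
The CSV formatting in on_fit_epoch_end is quite dense, so here is the same logic pulled out into a standalone sketch. The function name append_metrics and the metric names are made up for the example, and the file is written to a temporary directory instead of save_dir.

```python
import tempfile
from pathlib import Path


def append_metrics(file, metrics, epoch):
    """Append one row of metrics to a CSV file, writing the header on first use."""
    n = len(metrics) + 1  # number of columns, including 'epoch'
    # '%20s,' * n builds one right-aligned, 20-character field per column
    header = '' if file.exists() else (
        ('%20s,' * n % tuple(['epoch'] + list(metrics.keys()))).rstrip(',') + '\n')
    # '%20.5g' formats each value with up to 5 significant digits
    row = ('%20.5g,' * n % tuple([epoch] + list(metrics.values()))).rstrip(',') + '\n'
    with open(file, 'a') as f:
        f.write(header + row)


with tempfile.TemporaryDirectory() as tmp:
    file = Path(tmp) / 'results.csv'
    # Hypothetical metric names and values
    append_metrics(file, {'train/loss': 0.5, 'val/acc': 0.75}, epoch=0)
    append_metrics(file, {'train/loss': 0.25, 'val/acc': 0.875}, epoch=1)
    lines = file.read_text().splitlines()

for line in lines:
    print(line)
```

The header is only written when the file does not exist yet, so every later call just appends a row. This also means the metric keys have to stay the same between epochs, otherwise the columns no longer line up with the header.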

After defining the classes, the callbacks are registered using the code below:

callbacks = CallbackRepo()
# save_dir is expected to be a pathlib.Path and LOGGER a logging.Logger, both defined elsewhere
loggers = Loggers(save_dir, LOGGER)

for k in methods(loggers):
    callbacks.register_action(k, callback=getattr(loggers, k))

Here methods(loggers) returns the names of all callable attributes of the loggers instance, excluding those whose names start with a double underscore. This is done by the code:

# Modified from https://github.com/ultralytics/yolov5/blob/master/utils/general.py
def methods(instance):
    # Get class/instance methods
    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
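
To see what this helper actually picks up, the sketch below runs it on a small made-up class. TinyLogger and its attributes are hypothetical and exist only for this illustration.

```python
def methods(instance):
    # Names of callable attributes that do not start with a double underscore
    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]


class TinyLogger:
    """Hypothetical class used only to show what methods() returns."""

    def __init__(self):
        self.save_dir = "runs/exp"  # plain attribute, not callable, so it is skipped
        self.formatter = str.upper  # callable attribute, so it is returned too

    def on_train_start(self):
        pass

    def on_train_end(self):
        pass

    def _helper(self):
        pass


names = methods(TinyLogger())
print(names)
```

Notice that the result includes _helper and formatter as well as the two hook methods, because the filter only drops names that start with a double underscore. Passing those names to register_action would trip its hook assertion, so with this registration approach every callable attribute on the logger object needs to be named after a valid hook.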

Finally, the callbacks are executed during different stages of training. An example of the use of callbacks is shown below:

...
        # Training code
        for bi, (X, y) in enumerate(tk0):
            # Code for a forward and a backward pass of the model
            ...

            # Run callback to log the model and images to tensorboard
            callbacks('on_train_batch_end', epoch, bi, model, X)

        # Rest of the code
        ...
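
The article does not show how the stop_training flag is consumed, so the following is only one possible sketch of a loop that fires hooks and checks the flag after each batch. It uses a trimmed stand-in for CallbackRepo and a fake loss instead of a real model. The train function, the stop_when_small callback and the arguments passed to on_train_batch_end are all assumptions made for this example.

```python
class CallbackRepo:
    """Trimmed stand-in for the class above so this snippet runs standalone."""

    def __init__(self):
        self._callbacks = {'on_train_start': [], 'on_train_batch_end': [], 'on_train_end': []}
        self.stop_training = False  # set True to interrupt training

    def register_action(self, hook, name='', callback=None):
        assert hook in self._callbacks and callable(callback)
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def __call__(self, hook, *args, **kwargs):
        assert hook in self._callbacks
        for action in self._callbacks[hook]:
            action['callback'](*args, **kwargs)


def train(callbacks, epochs=3, batches_per_epoch=4):
    """Dummy loop: no model, just a fake loss that halves on every batch."""
    callbacks('on_train_start')
    loss = 1.0
    batches_run = 0
    for epoch in range(epochs):
        for bi in range(batches_per_epoch):
            loss *= 0.5  # stand-in for forward pass, backward pass and optimizer step
            batches_run += 1
            callbacks('on_train_batch_end', epoch, bi, loss)
            if callbacks.stop_training:
                break
        if callbacks.stop_training:
            break
    callbacks('on_train_end')
    return batches_run


callbacks = CallbackRepo()
log = []


def stop_when_small(epoch, bi, loss, threshold=0.1):
    # Hypothetical early-stopping rule: flip the flag once the loss drops below the threshold
    if loss < threshold:
        callbacks.stop_training = True


callbacks.register_action('on_train_start', callback=lambda: log.append('start'))
callbacks.register_action('on_train_batch_end', callback=stop_when_small)
callbacks.register_action('on_train_end', callback=lambda: log.append('end'))

batches_run = train(callbacks)
print(batches_run, log)
```

The training function never decides when to stop. It only checks the flag, while the stopping rule lives in a callback that can be swapped for another one without touching the loop.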

Conclusion

Callbacks are an effective way of adding functionality to a process. They reduce coupling: in this case, the callback functions can be replaced without changing the training method. They also simplify the code and let each method focus on one task at a time. This post showed the concept of callbacks and an example of their use. I hope you liked it and found it helpful.