How to write custom callbacks for AI training
Introduction
Callbacks are essentially a set of functions that are executed at different stages of a program's lifecycle. When training deep learning models, callbacks can run at different stages of the training process. For example, they can run before training starts, after a batch of inputs is passed through the model, or after the optimizer updates the model parameters.
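To make the idea concrete, here is a minimal, framework-free sketch that is not taken from the post's code. A toy training function accepts lists of callbacks and runs them before training starts and after each batch. The function name train and the batch values are purely illustrative.

```python
def train(batches, on_train_start=(), on_batch_end=()):
    """Toy 'training' loop that fires callbacks at two stages."""
    for callback in on_train_start:
        callback()
    total = 0
    for batch_index, batch in enumerate(batches):
        total += sum(batch)  # stand-in for a forward/backward pass
        for callback in on_batch_end:
            callback(batch_index, total)
    return total


events = []
result = train(
    [[1, 2], [3, 4]],
    on_train_start=[lambda: events.append("start")],
    on_batch_end=[lambda i, running: events.append(f"batch {i}: {running}")],
)
print(events)  # ['start', 'batch 0: 3', 'batch 1: 10']
```

The loop never needs to know what the callbacks do. Swapping the lambdas for logging or checkpointing functions leaves train untouched.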
This post demonstrates a CallbackRepo class that has a Python dictionary as an attribute. The dictionary maps each stage of the program to a list of callback functions to execute at that stage. These stages are also referred to as hooks. The code in this article is taken and modified from the YOLOv5 GitHub repository.
CallbackRepo
The CallbackRepo class has four methods, including its initializer. The code for the __init__ method is shown below:
```python
class CallbackRepo:
    """
    Handles all registered callbacks for a deep learning model
    """

    def __init__(self):
        # Define the available callbacks
        self._callbacks = {
            'on_pretrain_routine_start': [],
            'on_pretrain_routine_end': [],
            'on_train_start': [],
            'on_train_epoch_start': [],
            'on_train_batch_start': [],
            'optimizer_step': [],
            'on_before_zero_grad': [],
            'on_train_batch_end': [],
            'on_train_epoch_end': [],
            'on_val_start': [],
            'on_val_batch_start': [],
            'on_val_image_end': [],
            'on_val_batch_end': [],
            'on_val_end': [],
            'on_fit_epoch_end': [],  # fit = train + val
            'on_model_save': [],
            'on_train_end': [],
            'on_params_update': [],
            'teardown': [],
        }
        self.stop_training = False  # set True to interrupt training
```
The initializer sets two attributes. The first, _callbacks, is a dictionary that holds the list of callbacks registered to each hook. The second, stop_training, is a flag that signals that model training should stop.
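Writing a separate [] for every hook matters. A tempting shortcut is dict.fromkeys(hook_names, []), but that makes every hook share one list object. Registering an action on one hook would then register it on all of them. The sketch below shows the pitfall and a dict comprehension that builds the same structure as the literal in __init__, with an independent list per hook. The short hook_names tuple is illustrative only.

```python
hook_names = ('on_train_start', 'on_train_batch_end', 'on_train_end')

# Pitfall: every key points to the SAME list object
shared = dict.fromkeys(hook_names, [])
shared['on_train_start'].append('print_banner')
print(shared['on_train_end'])  # ['print_banner'] (leaked into another hook)

# Correct: a dict comprehension creates a fresh list for each hook
callbacks = {name: [] for name in hook_names}
callbacks['on_train_start'].append('print_banner')
print(callbacks['on_train_end'])  # []
```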
After the initializer, the class has a method to register the callback functions to the respective hooks. The code for that is shown below.
```python
...
    def register_action(self, hook, name='', callback=None):
        """
        Register a new action to a callback hook

        Args:
            hook: The callback hook name to register the action to
            name: The name of the action for later reference
            callback: The callback to fire
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})
```
The register_action method takes a hook name, a name for the callback and the callback function. It first checks that the hook is valid, and then that the callback is callable. Finally, it appends a dictionary holding the name and the callback to the list stored under that hook in self._callbacks.
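One caveat with register_action is that Python strips assert statements when run with the -O flag, so both checks silently disappear in optimized mode. A possible variant raises explicit exceptions instead, so the validation always runs. It is sketched below as a hypothetical StrictCallbackRepo with only two hooks for brevity. This is a design alternative, not the post's code.

```python
class StrictCallbackRepo:
    """Hypothetical variant of CallbackRepo that validates with exceptions instead of assert."""

    def __init__(self):
        # Only two hooks, to keep the example short
        self._callbacks = {'on_train_start': [], 'on_train_end': []}

    def register_action(self, hook, name='', callback=None):
        if hook not in self._callbacks:
            raise ValueError(f"hook '{hook}' not found in callbacks {list(self._callbacks)}")
        if not callable(callback):
            raise TypeError(f"callback '{callback}' is not callable")
        self._callbacks[hook].append({'name': name, 'callback': callback})


repo = StrictCallbackRepo()
repo.register_action('on_train_start', name='hello', callback=print)

rejected = []
for hook, callback in [('on_typo_start', print), ('on_train_end', 'not a function')]:
    try:
        repo.register_action(hook, callback=callback)
    except (ValueError, TypeError) as err:
        rejected.append(type(err).__name__)

print(rejected)  # ['ValueError', 'TypeError']
```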
To obtain the callbacks assigned to a specific hook, the class implements a get_registered_actions method. It takes an optional hook name and returns the list of callbacks registered to that hook. If no hook is given, it returns the whole dictionary. The code for this method is shown below:
```python
...
    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook

        Args:
            hook: The name of the hook to check, defaults to all
        """
        return self._callbacks[hook] if hook else self._callbacks
```
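Note that get_registered_actions hands back the internal list itself, not a copy, so a caller that mutates the result also changes the registry. The sketch below demonstrates this with MiniRepo, a hypothetical trimmed-down stand-in class with one hook. It also shows a defensive-copy alternative, which is a design option rather than what the original code does.

```python
class MiniRepo:
    """Trimmed-down stand-in for CallbackRepo with a single hook."""

    def __init__(self):
        self._callbacks = {'on_train_end': []}

    def register_action(self, hook, name='', callback=None):
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def get_registered_actions(self, hook=None):
        # Same behaviour as the post: returns the live internal objects
        return self._callbacks[hook] if hook else self._callbacks

    def get_registered_actions_copy(self, hook=None):
        # Hypothetical alternative: return shallow copies so callers cannot mutate the registry
        if hook:
            return list(self._callbacks[hook])
        return {k: list(v) for k, v in self._callbacks.items()}


repo = MiniRepo()
repo.register_action('on_train_end', name='close_tb', callback=print)

repo.get_registered_actions_copy('on_train_end').clear()
after_copy_clear = len(repo.get_registered_actions('on_train_end'))
print(after_copy_clear)  # 1: the registry is untouched

repo.get_registered_actions('on_train_end').clear()
after_live_clear = len(repo.get_registered_actions('on_train_end'))
print(after_live_clear)  # 0: the live list was cleared
```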
Finally, the class implements a __call__ method, which makes instances of the class callable. The code for this method is shown below:
```python
...
    def __call__(self, hook, *args, **kwargs):
        """
        Loop through the registered actions and fire all callbacks

        Args:
            hook: The name of the hook whose callbacks should be fired
            args: Arguments to receive from a deep learning model
            kwargs: Keyword arguments to receive from a deep learning model
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        for logger in self._callbacks[hook]:
            logger['callback'](*args, **kwargs)
```
The method takes a hook name plus any positional and keyword arguments that the callbacks of that hook need. It first checks that the hook is valid. It then calls each registered callback in turn, in the order they were registered, passing along the arguments and keyword arguments. The full code of this class is shown below:
```python
# Taken and modified from: https://github.com/ultralytics/yolov5/blob/master/utils/callbacks.py
class CallbackRepo:
    """
    Handles all registered callbacks for a deep learning model
    """

    def __init__(self):
        # Define the available callbacks
        self._callbacks = {
            'on_pretrain_routine_start': [],
            'on_pretrain_routine_end': [],
            'on_train_start': [],
            'on_train_epoch_start': [],
            'on_train_batch_start': [],
            'optimizer_step': [],
            'on_before_zero_grad': [],
            'on_train_batch_end': [],
            'on_train_epoch_end': [],
            'on_val_start': [],
            'on_val_batch_start': [],
            'on_val_image_end': [],
            'on_val_batch_end': [],
            'on_val_end': [],
            'on_fit_epoch_end': [],  # fit = train + val
            'on_model_save': [],
            'on_train_end': [],
            'on_params_update': [],
            'teardown': [],
        }
        self.stop_training = False  # set True to interrupt training

    def register_action(self, hook, name='', callback=None):
        """
        Register a new action to a callback hook

        Args:
            hook: The callback hook name to register the action to
            name: The name of the action for later reference
            callback: The callback to fire
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook

        Args:
            hook: The name of the hook to check, defaults to all
        """
        return self._callbacks[hook] if hook else self._callbacks

    def __call__(self, hook, *args, **kwargs):
        """
        Loop through the registered actions and fire all callbacks

        Args:
            hook: The name of the hook whose callbacks should be fired
            args: Arguments to receive from a deep learning model
            kwargs: Keyword arguments to receive from a deep learning model
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        for logger in self._callbacks[hook]:
            logger['callback'](*args, **kwargs)
```
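The following usage sketch shows how __call__ behaves. Actions on the same hook fire in registration order, and all of them receive the same arguments. To keep the sketch runnable on its own, it uses a compact copy of the class above reduced to two hooks. The record_loss and record_progress functions are hypothetical.

```python
class CallbackRepo:
    """Compact copy of the class above, reduced to two hooks."""

    def __init__(self):
        self._callbacks = {'on_train_batch_end': [], 'on_train_end': []}
        self.stop_training = False

    def register_action(self, hook, name='', callback=None):
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def __call__(self, hook, *args, **kwargs):
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        for logger in self._callbacks[hook]:
            logger['callback'](*args, **kwargs)


calls = []


def record_loss(epoch, batch_index, loss=None):
    calls.append(('loss', epoch, batch_index, loss))


def record_progress(epoch, batch_index, loss=None):
    calls.append(('progress', epoch, batch_index))


callbacks = CallbackRepo()
callbacks.register_action('on_train_batch_end', name='loss', callback=record_loss)
callbacks.register_action('on_train_batch_end', name='progress', callback=record_progress)

# Positional and keyword arguments are forwarded unchanged to every action
callbacks('on_train_batch_end', 0, 5, loss=0.25)
callbacks('on_train_end')  # no actions registered, so nothing happens
print(calls)  # [('loss', 0, 5, 0.25), ('progress', 0, 5)]
```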
Loggers
This section presents a Loggers class, which is responsible for logging the training metrics, the model structure and sample images to TensorBoard. It also logs the metrics to a CSV file. The code is shown below:
```python
import warnings
from pathlib import Path

import torch
import torchvision
from torch.utils.tensorboard.writer import SummaryWriter

LOGGERS = ('csv', 'tb')


class Loggers():
    # YOLOv5 Loggers class
    def __init__(self, save_dir=None, logger=None, include=LOGGERS):
        self.save_dir = Path(save_dir)  # required: results.csv and TensorBoard files are written here
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.logger = logger  # for printing results to console
        self.include = include
        for k in LOGGERS:
            setattr(self, k, None)  # init empty logger dictionary
        self.csv = True  # always log to csv

        # TensorBoard
        s = self.save_dir
        if 'tb' in self.include:
            prefix = 'TensorBoard: '
            if self.logger:
                self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/")
            self.tb = SummaryWriter(str(s))

    def on_train_start(self):
        # Callback runs on train start
        pass

    def on_pretrain_routine_end(self):
        pass

    def on_train_batch_end(self, epoch, batch_index, model, imgs, plots=True):
        # Callback runs on train batch end: log the first batch and the model graph once
        if plots and self.tb:
            if epoch == 0 and batch_index == 0:
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore')  # suppress jit trace warning
                    grid = torchvision.utils.make_grid(imgs)
                    self.tb.add_image('images', grid, 0)
                    self.tb.add_graph(torch.jit.trace(model, imgs[0:1], strict=False), [])  # type: ignore

    def on_train_epoch_end(self, epoch):
        pass

    def on_val_image_end(self, pred, predn, path, names, im):
        pass

    def on_val_end(self):
        pass

    def on_fit_epoch_end(self, metrics, epoch):
        # Callback runs at the end of each fit (train+val) epoch
        # metrics: dict mapping metric name -> value
        if self.csv:
            file = self.save_dir / 'results.csv'
            n = len(metrics) + 1  # number of cols
            s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + list(metrics.keys()))).rstrip(',') + '\n')  # add header
            with open(file, 'a') as f:
                f.write(s + ('%20.5g,' * n % tuple([epoch] + list(metrics.values()))).rstrip(',') + '\n')

        if self.tb:
            for k, v in metrics.items():
                self.tb.add_scalar(k, v, epoch)

    def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
        pass

    def on_train_end(self):
        if self.tb:  # tb is None when 'tb' is not in include
            self.tb.close()

    def on_params_update(self, params):
        pass
```
Note that many of these methods are left empty. Functionality will be added to them in future blog posts. For now, the class does two things. First, at the end of the very first batch, the on_train_batch_end method logs the model graph and the batch of images to TensorBoard. Second, at the end of every fit (train + val) epoch, the on_fit_epoch_end method logs the metrics to a CSV file and to TensorBoard.
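The CSV logic in on_fit_epoch_end writes a header row only when results.csv does not exist yet. It then appends one row per epoch, with every field right-aligned in a 20-character column. The standalone sketch below reproduces only that formatting in a temporary directory, so torch is not needed. The helper name append_metrics and the metric names and values are made up for illustration.

```python
import tempfile
from pathlib import Path


def append_metrics(file, metrics, epoch):
    # Same formatting as Loggers.on_fit_epoch_end
    n = len(metrics) + 1  # number of columns: epoch + one per metric
    header = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + list(metrics.keys()))).rstrip(',') + '\n')
    with open(file, 'a') as f:
        f.write(header + ('%20.5g,' * n % tuple([epoch] + list(metrics.values()))).rstrip(',') + '\n')


with tempfile.TemporaryDirectory() as tmp:
    csv_file = Path(tmp) / 'results.csv'
    append_metrics(csv_file, {'train/loss': 0.91, 'val/loss': 1.02}, 0)
    append_metrics(csv_file, {'train/loss': 0.55, 'val/loss': 0.71}, 1)
    lines = csv_file.read_text().splitlines()

print(len(lines))  # 3: one header row plus one row per epoch
print([cell.strip() for cell in lines[0].split(',')])  # ['epoch', 'train/loss', 'val/loss']
```

Because the header depends only on whether the file exists, resuming a run into the same directory keeps appending rows under the original header.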
After defining the classes, the callbacks are registered using the code below:
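```python
# save_dir (a directory path) and LOGGER (e.g. a logging.Logger) are defined elsewhere in the training script
callbacks = CallbackRepo()
loggers = Loggers(save_dir, LOGGER)
for k in methods(loggers):
    # Each Loggers method is named after the hook it should be attached to
    callbacks.register_action(k, callback=getattr(loggers, k))
```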
Here methods(loggers) returns the names of all callable attributes of the loggers instance, excluding names that start with a double underscore. This is done by the following function:
```python
# Modified from https://github.com/ultralytics/yolov5/blob/master/utils/general.py
def methods(instance):
    # Get class/instance methods
    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
```
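Because methods() only skips names that start with a double underscore, any other callable attribute would also be passed to register_action and trip its hook assertion. A single-underscore helper is one example. The sketch below uses a hypothetical ToyLoggers class to show what methods() returns. It also shows a possible filter that keeps only names that are actual hooks. In the post's setting, the set of valid names could come from callbacks.get_registered_actions(), which returns the whole hook dictionary. Here, HOOKS is an illustrative subset.

```python
def methods(instance):
    # Get class/instance methods (same as the post)
    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]


class ToyLoggers:
    """Hypothetical logger with two hook methods and one private helper."""

    def __init__(self):
        self.csv = True  # not callable, so methods() skips it

    def on_train_start(self):
        pass

    def on_train_end(self):
        pass

    def _format_row(self, values):
        return ','.join(map(str, values))


HOOKS = {'on_train_start', 'on_train_end', 'on_fit_epoch_end'}  # illustrative subset

found = methods(ToyLoggers())  # dir() returns names in sorted order
print(found)  # ['_format_row', 'on_train_end', 'on_train_start']

# Keep only names that correspond to a known hook before registering them
hook_methods = [name for name in found if name in HOOKS]
print(hook_methods)  # ['on_train_end', 'on_train_start']
```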
Finally, the callbacks are executed during different stages of training. An example of the use of callbacks is shown below:
```python
...
# Training code (tk0 iterates over the training batches, e.g. a tqdm-wrapped DataLoader)
for bi, (X, y) in enumerate(tk0):
    # Code for a forward and a backward pass of the model
    ...

    # Run callback to log the model and images to TensorBoard
    callbacks('on_train_batch_end', epoch, bi, model, X)

    # Rest of the code
    ...
```
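The snippet above fires only one hook. The sketch below shows how the stop_training flag from CallbackRepo could be used: a callback sets the flag, and the loop checks it after each epoch. It is framework-free and rests on several assumptions:

- A plain list stands in for the DataLoader.
- Callbacks is a compact two-hook stand-in for CallbackRepo.
- early_stop is a hypothetical patience-based rule.
- The validation losses are made up.

```python
class Callbacks:
    """Compact stand-in for CallbackRepo with two hooks."""

    def __init__(self):
        self._callbacks = {'on_train_batch_end': [], 'on_fit_epoch_end': []}
        self.stop_training = False  # set True to interrupt training

    def register_action(self, hook, name='', callback=None):
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def __call__(self, hook, *args, **kwargs):
        for action in self._callbacks[hook]:
            action['callback'](*args, **kwargs)


callbacks = Callbacks()
seen_batches = []
val_losses = [0.9, 0.7, 0.71, 0.72, 0.5]  # made-up validation losses, one per epoch
best = {'loss': float('inf'), 'bad_epochs': 0}


def log_batch(epoch, batch_index):
    seen_batches.append((epoch, batch_index))


def early_stop(metrics, epoch, patience=2):
    # Hypothetical early stopping: stop after `patience` epochs without improvement
    if metrics['val/loss'] < best['loss']:
        best['loss'], best['bad_epochs'] = metrics['val/loss'], 0
    else:
        best['bad_epochs'] += 1
    if best['bad_epochs'] >= patience:
        callbacks.stop_training = True


callbacks.register_action('on_train_batch_end', callback=log_batch)
callbacks.register_action('on_fit_epoch_end', callback=early_stop)

batches = [[1.0, 2.0], [3.0, 4.0]]  # stand-in for a DataLoader
for epoch, val_loss in enumerate(val_losses):
    for bi, X in enumerate(batches):
        # forward/backward pass would go here
        callbacks('on_train_batch_end', epoch, bi)
    callbacks('on_fit_epoch_end', {'val/loss': val_loss}, epoch)
    if callbacks.stop_training:
        break

print(epoch)  # 3: epochs 2 and 3 did not improve on 0.7
```

The training loop only reads the flag. The decision to stop lives entirely in a callback, so the stopping rule can change without editing the loop.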
Conclusion
Callbacks are an effective way of adding functionality to a process. They reduce coupling: here, the callback functions can be replaced without changing the training loop. They also simplify the code and let each method focus on a single task. This post showed the concept of callbacks and an example of their use. I hope you liked it and found it helpful.