A simple training method in PyTorch

Introduction

This article walks through a very simple training loop for a deep learning model, implemented in PyTorch. The code for the training function is shown below. The function is called fit and accepts six arguments.

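import torch
from tqdm import tqdm

# Use the GPU when one is available. The model is assumed to already be on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def fit(opts, model, train_dataloader, criterion, optimizer, scheduler):
    for epoch in range(opts.num_epochs):
        model.train()  # training behaviour for layers such as dropout and batch norm
        running_loss = 0.0  # sum of per-sample losses seen so far in this epoch
        num_samples = 0  # number of samples seen so far in this epoch
        tk0 = tqdm(train_dataloader, total=len(train_dataloader))
        for X, y in tk0:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()  # clear gradients left over from the previous batch
            with torch.set_grad_enabled(True):
                y_pred = model(X)
                loss = criterion(y_pred, y)
                loss.backward()
                optimizer.step()
                scheduler.step()  # stepped once per batch, with no arguments
            # loss.item() is the mean loss over the batch, so scale it by the batch size
            running_loss += loss.item() * X.size(0)
            num_samples += X.size(0)
            tk0.set_postfix(loss=running_loss / num_samples)
        # Average loss per sample over the whole epoch
        epoch_loss = running_loss / num_samples
        print(f"Training loss: {epoch_loss}")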

Code description

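The first argument is named 'opts'. It holds all the options required for training, such as hyperparameters. Because the code reads opts.num_epochs with attribute access, opts has to be a namespace-like object (for example an argparse.Namespace) rather than a plain Python 'dict'. The method starts a loop that runs for the number of epochs specified by opts.num_epochs.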
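As a minimal sketch, opts can be built with types.SimpleNamespace so that opts.num_epochs works as attribute access. The batch_size and lr fields below are made-up examples, not taken from the original code; only num_epochs is read by fit.

```python
from types import SimpleNamespace

# Hypothetical option values; only `num_epochs` is actually used by fit().
opts = SimpleNamespace(num_epochs=5, batch_size=32, lr=1e-3)

# Attribute access works, which is what `opts.num_epochs` in fit() relies on.
epochs_seen = [epoch for epoch in range(opts.num_epochs)]
print(epochs_seen)

# A plain dict stores the value under a key, not as an attribute.
opts_dict = {"num_epochs": 5}
has_attribute = hasattr(opts_dict, "num_epochs")
print(has_attribute)
```

An argparse.Namespace returned by ArgumentParser.parse_args() behaves the same way, which is convenient when the options come from the command line.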

The second argument is the model that will undergo training. The third argument is a dataloader, which splits the dataset into batches to be fed to the model. criterion takes the predicted labels and the actual labels and calculates the loss value. This loss value is then used to compute the gradients of the model parameters. The optimizer takes these gradients and updates the model parameters. A scheduler can be used to change certain hyperparameters during training; in this case the learning rate is reduced once the training loss value starts to stabilize, so that the model can learn more efficiently. Note that fit calls scheduler.step() without arguments after every batch, which suits schedulers that are stepped per batch. A plateau-based scheduler such as torch.optim.lr_scheduler.ReduceLROnPlateau instead expects the monitored loss as an argument and is usually stepped once per epoch.
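The following sketch illustrates two of these pieces in isolation: how a DataLoader splits a dataset into batches, and how a plateau-based scheduler lowers the learning rate once the loss stops improving. The toy dataset, the nn.Linear stand-in model, the loss sequence and the choice of ReduceLROnPlateau are all assumptions made for illustration; the original post does not say which scheduler it uses.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy dataset: 100 samples, 4 features, 3 classes (made up for illustration).
dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 3, (100,)))
train_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# The dataloader yields three full batches and one smaller final batch.
batch_sizes = [X.size(0) for X, y in train_dataloader]
print(batch_sizes)

model = nn.Linear(4, 3)  # stand-in for a real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2
)

# Simulated epoch losses that stop improving after the second epoch.
learning_rates = []
for epoch_loss in [1.0, 0.5, 0.5, 0.5, 0.5, 0.5]:
    scheduler.step(epoch_loss)  # plateau schedulers take the monitored value
    learning_rates.append(optimizer.param_groups[0]["lr"])
print(learning_rates)
```

With patience=2, the learning rate is multiplied by factor only after the loss has failed to improve for more than two consecutive calls, so the first few entries stay at the initial value.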

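An epoch is one full pass over the training dataset, and opts.num_epochs sets how many passes are made. At the start of each epoch, the model is set to training mode with the model.train() call. Training mode does not itself enable parameter updates; it switches layers such as dropout and batch normalization to their training behaviour. Then, the running loss is initialized to zero. The running loss accumulates the summed loss over all samples seen so far in the epoch and is later divided to obtain an average. Inside the batch loop, optimizer.zero_grad() clears the gradients left over from the previous batch. PyTorch accumulates gradients across backward() calls by default, so without this call the old gradients would be added to the new ones.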
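Both behaviours can be checked with a small sketch. It uses a standalone nn.Dropout layer and a single hand-made parameter w, which are illustrative stand-ins and not part of the original fit function.

```python
import torch
from torch import nn

torch.manual_seed(0)

# 1) train() vs eval(): dropout is only active in training mode.
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train()
train_out = dropout(x)  # some entries zeroed, survivors scaled by 1 / (1 - p)
dropout.eval()
eval_out = dropout(x)  # dropout is a no-op in evaluation mode

# 2) Gradients accumulate across backward() calls until they are cleared.
w = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

(2 * w).sum().backward()
(2 * w).sum().backward()
accumulated = w.grad.item()  # the two gradients of 2.0 have been added

optimizer.zero_grad()
(2 * w).sum().backward()
fresh = w.grad.item()  # only the gradient from the latest backward() call

print(accumulated, fresh)
```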

After that, an inner loop is started, which obtains a training batch from the dataloader, moves it to the target device, passes it through the model and gets the predicted output. Then the loss is calculated from the predictions and the true labels. The line loss.backward() calculates the gradients of the loss with respect to the model parameters, and the optimizer.step() call uses them to update those parameters. The scheduler.step() call then adjusts the learning rate according to the scheduler's rule. loss.item() returns the batch loss as a plain Python number. With the default reduction="mean" used by most PyTorch loss functions, this is the average loss per sample in the batch, or in this case per training image. Multiplying it by the batch size, X.size(0), therefore gives the total loss for the batch. At the end of each epoch, the average training loss is printed out to the terminal. The tqdm library is used to display a progress bar of the training process.
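The relationship between loss.item() and the batch total can be verified with a short sketch, followed by a single backward-and-step update. It assumes nn.CrossEntropyLoss as the criterion and an nn.Linear stand-in model with random inputs; the original post does not specify either.

```python
import torch
from torch import nn

torch.manual_seed(0)

logits = torch.randn(8, 3)  # predictions for a batch of 8 samples
labels = torch.randint(0, 3, (8,))

mean_loss = nn.CrossEntropyLoss()(logits, labels)  # default reduction="mean"
sum_loss = nn.CrossEntropyLoss(reduction="sum")(logits, labels)
per_sample = nn.CrossEntropyLoss(reduction="none")(logits, labels)

# Mean loss times batch size recovers the summed loss over the batch.
batch_total = mean_loss.item() * logits.size(0)
print(batch_total, sum_loss.item())

# One update step: backward() fills .grad, step() changes the parameters.
model = nn.Linear(3, 3)  # stand-in for a real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
before = model.weight.detach().clone()

loss = nn.CrossEntropyLoss()(model(logits), labels)
loss.backward()
optimizer.step()
changed = not torch.equal(before, model.weight.detach())
print(changed)
```

If the criterion were created with reduction="sum", loss.item() would already be the batch total and the multiplication by X.size(0) would have to be dropped.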

Conclusion

This post demonstrated a very simple training method using the PyTorch library. For more information on the concepts discussed here, visit the PyTorch docs. This post is also inspired by this Kaggle notebook. I hope you enjoyed reading this post. Follow this blog if you want to know more about concepts in AI.