Creating a custom Dataset class in PyTorch

Required data handling before model training

Before a deep-learning model can be created and trained, the dataset has to be prepared so that it can be fed into the training pipeline. This article focuses on building a custom dataset class that converts images to tensors. This post is part of a series on writing a simple image classifier using PyTorch. If you don't have PyTorch installed, you can follow this guide to do so. The dataset used for this example is taken from the Kaggle APTOS 2019 Blindness Detection competition. The code is a modification of this notebook and the official docs.

Step 1: Import statements

import pandas as pd
import os

from PIL import Image, ImageFile

import torch
from torch.utils.data import Dataset
from torchvision import transforms

The custom dataset class depends on pandas, PIL, os, PyTorch and torchvision. The pandas library is used to load the 'train.csv' file, which contains the id of every image and its target label. The Pillow library, or PIL, is used for operations such as reading an image and resizing it. The os module provides an interface to the operating system; in our case it is used to build the path to each image file. Finally, the Dataset class is imported from PyTorch, and the transforms module is imported from torchvision, its companion library for computer vision.
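
As a rough illustration of how pandas and os work together here, the sketch below loads a tiny in-memory stand-in for 'train.csv' and builds an image path from it. The two ids and the image_dir folder name are made up for the example; only the column names id_code and diagnosis come from the real file.

```python
import io
import os

import pandas as pd

# Hypothetical two-row stand-in for train.csv; the ids are invented.
csv_text = "id_code,diagnosis\nabc123,0\ndef456,2\n"
data = pd.read_csv(io.StringIO(csv_text))

# Hypothetical image folder, used only for illustration.
image_dir = "train_images"

# Build the path of the first image from its id, as the dataset class does.
first_path = os.path.join(image_dir, data.loc[0, "id_code"] + ".png")
first_label = int(data.loc[0, "diagnosis"])

print(len(data))    # number of rows, i.e. number of images
print(first_path)   # path of the first image file
print(first_label)  # label of the first image
```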

Step 2: Custom Dataset class

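# Folder with the training images and csv file with the ids and labels
train_image_path = './data/blindness_detection/aptos2019-blindness-detection/train_images/'
train_csv_path = './data/blindness_detection/aptos2019-blindness-detection/train.csv'

class BlindnessDatasetTrain(Dataset):
    def __init__(self, input_csv_path):
        # One row per image, with the columns 'id_code' and 'diagnosis'
        self.data = pd.read_csv(input_csv_path)

    def __len__(self):
        # Number of training images
        return len(self.data)

    def __getitem__(self, idx):
        # Build the image path from the id stored in row idx
        train_image_file = os.path.join(
            train_image_path,
            self.data.loc[idx, 'id_code'] + '.png')
        image = Image.open(train_image_file)
        # Resize every image to the same 256x256 size
        image = image.resize((256, 256), resample=Image.BILINEAR)
        # Convert the PIL image to a tensor
        image = transforms.ToTensor()(image)
        # Convert the label to a tensor
        label = torch.tensor(self.data.loc[idx, 'diagnosis'])
        return image, label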

First, two variables hold the path to 'train.csv' and the path to the folder containing the training images. The csv path is passed to the dataset class when it is created, while the image folder path is read as a module-level variable inside __getitem__; together they let the class read the images and labels and convert them to tensors. Ideally, these paths would be read from the command line or from a config file into a dictionary, and the dictionary would then be passed to the class during its initialization. This article, however, sticks to the bare minimum version to keep the code easy to follow. In future blog posts, this class will be refactored into a more robust version.
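
To give an idea of what the command-line variant mentioned above could look like, here is a minimal sketch using argparse from the standard library. The flag names, their defaults and the parse_config helper are all hypothetical; they are not part of the code in this article.

```python
import argparse


def parse_config(argv=None):
    """Parse hypothetical command-line flags into a plain dictionary."""
    parser = argparse.ArgumentParser(description="Paths used by the dataset class")
    # Both flags and their default values are assumptions made for this sketch.
    parser.add_argument("--train-csv-path", default="train.csv")
    parser.add_argument("--train-image-path", default="train_images")
    args = parser.parse_args(argv)
    # vars() turns the parsed namespace into a dictionary
    return vars(args)


# Passing a list simulates command-line arguments; parse_config() with no
# argument would read them from sys.argv instead.
config = parse_config(["--train-csv-path", "data/train.csv"])
print(config["train_csv_path"])
print(config["train_image_path"])
```

A dataset class could then accept such a dictionary in __init__ and keep the image folder as an attribute instead of relying on a module-level variable.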

The custom dataset class is a subclass of the Dataset class imported from torch.utils.data. Its __init__ method reads the csv file into a pandas dataframe. It then implements two methods, __len__ and __getitem__. Both are expected of a map-style dataset: without __getitem__, indexing the dataset raises NotImplementedError, and without __len__, calling len() on the dataset raises TypeError. The __len__ method simply returns the number of rows in the dataframe, which is also the total number of images in the training set.
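
The two errors described above can be reproduced with a small experiment. The IncompleteDataset class below is a hypothetical subclass written only for this check, and the behaviour shown is what is expected from recent PyTorch versions.

```python
from torch.utils.data import Dataset


class IncompleteDataset(Dataset):
    """Hypothetical subclass that defines neither __len__ nor __getitem__."""


dataset = IncompleteDataset()

# Indexing falls back to Dataset.__getitem__, which is not implemented.
try:
    dataset[0]
    getitem_error = None
except NotImplementedError as err:
    getitem_error = type(err).__name__

# The base class defines no __len__, so len() has nothing to call.
try:
    len(dataset)
    len_error = None
except TypeError as err:
    len_error = type(err).__name__

print(getitem_error)
print(len_error)
```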

The __getitem__ method receives an index as its parameter. It reads the image and the label for that index, converts them to tensors and returns both. Note that in the line image = transforms.ToTensor()(image), transforms.ToTensor() is not a plain function call; it creates an object of the ToTensor class. The __call__ method of that object is then invoked with the image, and it returns the image tensor. Also note that the returned tensor follows the 'chw' convention, where the channel dimension comes first, before the height and the width of the image. Finally, to create an instance of the custom dataset class, simply call the class name as shown below:

def main():
    sampleDataset = BlindnessDatasetTrain(
        input_csv_path=train_csv_path
    )

if __name__ == "__main__":
    main()

Conclusion

Creating a custom dataset class is the first step in forming a training pipeline for a deep learning model. This article walked through a simple way of doing so. This post is part of a series that focuses on building a simple image classifier. It is inspired by this Kaggle notebook and the official data tutorial by PyTorch. To learn more, it is also worth checking out the source code of the PyTorch Dataset class. Thanks for reading! If you liked it, make sure to follow the blog and leave a comment :D.