Creating a custom Dataset class in PyTorch
Preparing the data before model training
Before deep-learning models can be created and trained, the dataset needs to be prepared so that it can be fed into the training loop. This article focuses on building a custom dataset class that reads images and converts them to tensors. This post is part of a series on writing a simple image classifier using PyTorch. If you don't have PyTorch installed, you can follow this guide to do so. The dataset used for this example is taken from the Kaggle APTOS 2019 Blindness Detection competition. The code is a modification of this notebook and the official docs.
Step 1: Import statements
import pandas as pd
import os
from PIL import Image, ImageFile
import torch
from torch.utils.data import Dataset
from torchvision import transforms
The custom dataset class depends on pandas, Pillow (PIL), os, PyTorch and torchvision. The pandas library is used to load the 'train.csv' file, which contains the id of every training image along with its target label. Pillow is used to read each image and resize it. The os module provides operating-system utilities; here, os.path.join is used to build the path of each image file. Finally, torch is used to create the label tensors, the Dataset base class comes from torch.utils.data, and the transforms module, which provides ToTensor, comes from torchvision.
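To make the roles of pandas and os concrete, here is a minimal sketch of how one row of 'train.csv' turns into an image path. It assumes the CSV has the 'id_code' and 'diagnosis' columns used later in this post; the two rows below are made-up placeholders, not real competition data.

```python
import io
import os

import pandas as pd

# Placeholder CSV content with the same columns as 'train.csv' (ids are made up).
csv_text = "id_code,diagnosis\nabc123,2\ndef456,0\n"
data = pd.read_csv(io.StringIO(csv_text))

image_dir = './data/blindness_detection/aptos2019-blindness-detection/train_images/'

# Look up the id in row idx and join it with the folder to get the image file path.
idx = 0
image_file = os.path.join(image_dir, data.loc[idx, 'id_code'] + '.png')

print(image_file)                   # ./data/blindness_detection/aptos2019-blindness-detection/train_images/abc123.png
print(data.loc[idx, 'diagnosis'])   # 2
```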
Step 2: Custom Dataset class
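train_image_path = './data/blindness_detection/aptos2019-blindness-detection/train_images/'
train_csv_path = './data/blindness_detection/aptos2019-blindness-detection/train.csv'


class BlindnessDatasetTrain(Dataset):
    def __init__(self, input_csv_path):
        # Load the CSV holding the image ids ('id_code') and labels ('diagnosis')
        self.data = pd.read_csv(input_csv_path)

    def __len__(self):
        # One CSV row per training image
        return len(self.data)

    def __getitem__(self, idx):
        # Build the image path from the image folder and the id stored in row idx
        train_image_file = os.path.join(
            train_image_path,
            self.data.loc[idx, 'id_code'] + '.png')
        # Read the image, make sure it has 3 RGB channels and resize it to 256x256
        image = Image.open(train_image_file).convert('RGB')
        image = image.resize((256, 256), resample=Image.BILINEAR)
        # PIL image (H x W x C, values 0-255) -> float tensor (C x H x W, values 0.0-1.0)
        image = transforms.ToTensor()(image)
        label = torch.tensor(self.data.loc[idx, 'diagnosis'])
        return image, label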
First, two module-level variables hold the path to the folder containing the training images and the path to 'train.csv'. The CSV path is passed to the class constructor, while the image folder path is read inside __getitem__ to locate each image file. Ideally, these values would be read from the command line or from a config file into a dictionary, and that dictionary would be passed to the class during initialization. This article sticks to the bare minimum version to keep it easy to follow; in future posts, the class will be refactored into a more robust version.
The custom dataset class is a subclass of the Dataset class imported from torch.utils.data. Its __init__ method reads the CSV file into a pandas DataFrame. It then implements two methods, __len__ and __getitem__. Both are expected of a map-style dataset: if __getitem__ is not overridden, indexing the dataset raises NotImplementedError, and because Dataset provides no default __len__, calling len() on the dataset raises TypeError. The __len__ method simply returns the number of rows in the DataFrame, which equals the total number of images in the training set.
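The error behaviour described above can be checked with a deliberately incomplete subclass. This is a small sketch; the class name IncompleteDataset and the helper error_name are made up for the demonstration.

```python
from torch.utils.data import Dataset


class IncompleteDataset(Dataset):
    """Hypothetical subclass that forgets to implement __len__ and __getitem__."""
    pass


def error_name(fn):
    """Run fn and return the name of the exception it raises, or None if it succeeds."""
    try:
        fn()
    except Exception as err:
        return type(err).__name__
    return None


dataset = IncompleteDataset()
print(error_name(lambda: dataset[0]))    # NotImplementedError (inherited Dataset.__getitem__)
print(error_name(lambda: len(dataset)))  # TypeError (Dataset defines no __len__)
```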
The __getitem__ method receives an index as its parameter. It reads the image and the label for that index, converts them to tensors and returns the image and label tensors. Note the line image = transforms.ToTensor()(image): here, transforms.ToTensor() is not a function call that converts anything; it creates an instance of the ToTensor class. The instance's __call__ method is then invoked with the image and returns the image tensor. Also note that the returned tensor follows the CHW convention, where the channel dimension comes before the height and the width of the image, and that for 8-bit images its values are scaled from the 0 to 255 range to floats between 0 and 1. Finally, to create an instance of the custom dataset class, simply call the class with the path to the CSV file, as shown below:
def main():
    sample_dataset = BlindnessDatasetTrain(
        input_csv_path=train_csv_path
    )


if __name__ == "__main__":
    main()
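Creating the instance is usually followed by wrapping it in a torch.utils.data.DataLoader, which batches the (image, label) pairs returned by __getitem__. The sketch below uses a hypothetical FakeBlindnessDataset that returns random tensors of the same shapes, so it runs without the Kaggle files; with the real data, an instance of BlindnessDatasetTrain would be passed instead. The batch size of 4 is an arbitrary choice.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FakeBlindnessDataset(Dataset):
    """Hypothetical stand-in that mimics the shapes returned by BlindnessDatasetTrain."""

    def __init__(self, num_items=10):
        self.num_items = num_items

    def __len__(self):
        return self.num_items

    def __getitem__(self, idx):
        image = torch.rand(3, 256, 256)  # random pixels in CHW layout, values in [0, 1)
        label = torch.tensor(idx % 5)    # APTOS diagnosis labels range from 0 to 4
        return image, label


loader = DataLoader(FakeBlindnessDataset(), batch_size=4, shuffle=False)
images, labels = next(iter(loader))

print(images.shape)  # torch.Size([4, 3, 256, 256])
print(labels)        # tensor([0, 1, 2, 3])
print(len(loader))   # 3 batches: 4 + 4 + 2 items
```

Because every image is resized to 256x256 in __getitem__, the default collate function can stack the individual tensors into a single batch tensor; images of different sizes would make that stacking fail.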
Conclusion
Creating a custom dataset class is the first step in building a training pipeline for a deep-learning model, and this article walked through a minimal way of doing so. This post is part of a series that focuses on building a simple image classifier. It is inspired by this Kaggle notebook and the official data tutorial by PyTorch. To learn more, it is worth checking out the source code of the PyTorch Dataset class. Thanks for reading! If you liked it, make sure to follow the blog and leave a comment :D.