📚 Using a Custom Dataset in PyTorch

Oct 26, 2022 · 3 min read

In order to decouple dataset code and model code, PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
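To see the two primitives side by side, here is a minimal sketch that uses random tensors as stand-ins for real samples (an assumption for illustration). It wraps them in torch.utils.data.TensorDataset, a ready-made Dataset, and hands that to a DataLoader:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 fake samples with 4 features each, and an integer label in [0, 3)
features = torch.randn(100, 4)
labels = torch.randint(0, 3, (100,))

# Dataset: stores the samples and labels, indexable like a list
dataset = TensorDataset(features, labels)
x0, y0 = dataset[0]  # a single (feature, label) pair

# DataLoader: wraps the Dataset in an iterable of minibatches
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch_features, batch_labels in loader:
    print(batch_features.shape, batch_labels.shape)
```

With 100 samples and batch_size=32, the loop yields three batches of 32 samples followed by a final batch of 4.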

2.1 Loading a Dataset

Before we can use data, we need to have some. Fortunately, the PyTorch domain libraries (such as TorchVision) provide a number of pre-loaded datasets, all of which subclass torch.utils.data.Dataset. Here we use one of them, Fashion-MNIST, to show how to load a dataset.

import torch
from torch.utils.data import Dataset

from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="data",          # the path where the data is stored
    train=True,           # True for the training split, False for the test split
    download=True,        # downloads the data from the internet if it is not available at root
    transform=ToTensor()  # the transformation applied to the features (images)
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,          # load the test split
    download=True,
    transform=ToTensor()
)

2.2 Iterating and Visualizing the Dataset

We can index a Dataset manually like a list, e.g. training_data[index], and use matplotlib to visualize the images. To display a few random samples from Fashion-MNIST, we can do the following:

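import matplotlib.pyplot as plt
import torch

figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    # pick a random sample index in [0, len(training_data))
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(training_data.classes[label])  # class name instead of the integer label
    plt.axis("off")
    # img has shape (1, 28, 28); squeeze() drops the channel dimension for imshow
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()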

2.3 Creating a Custom Dataset

Sometimes the pre-loaded datasets can’t meet our needs, so we create a custom dataset by subclassing Dataset. A custom Dataset class must implement three functions: __init__, __len__, and __getitem__.

  1. __init__

    This function runs once when the Dataset object is instantiated. We typically use it to store the directory containing the data files, the annotations, and both transforms.

  2. __len__

    This function should return the number of samples.

  3. __getitem__

    This function loads and returns a sample from the dataset at the given index, applying the transforms first if any were provided.
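A minimal sketch of the three functions on synthetic data, with no files involved. SquaresDataset is a hypothetical class invented for illustration: sample i is the value i and its label is i squared.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (i, i**2)."""

    def __init__(self, n, transform=None, target_transform=None):
        # Runs once when the object is created: store everything needed later
        self.n = n
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # Number of samples in the dataset
        return self.n

    def __getitem__(self, idx):
        # Build sample idx on the fly, then apply the optional transforms
        if idx < 0 or idx >= self.n:
            raise IndexError(idx)
        x = torch.tensor([float(idx)])
        y = torch.tensor(float(idx ** 2))
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            y = self.target_transform(y)
        return x, y


ds = SquaresDataset(5, transform=lambda x: x / 10)
print(len(ds))  # 5
print(ds[3])    # (tensor([0.3000]), tensor(9.))
```

Rejecting out-of-range indices with IndexError is a defensive choice in this sketch; the DataLoader's default samplers only request indices in range(len(dataset)).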

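For example, the following dataset reads its images from img_dir and their labels from a CSV annotations file whose first column holds the image file name and whose second column holds the label: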

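import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # CSV with one row per image: column 0 = file name, column 1 = label
        # (note: read_csv treats the first row as a header by default)
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # one sample per row of the annotations file
        return len(self.img_labels)

    def __getitem__(self, idx):
        # locate the image on disk and decode it into a uint8 tensor (C x H x W)
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        # apply the optional feature and label transforms
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label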
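The same pattern does not depend on pandas or on image files. The sketch below is a hypothetical variant of the class above (CsvTensorDataset and its loader argument are our own names): it parses a header-less CSV of file_name,label rows with the standard csv module, and turns each path into a tensor with a loader callable that defaults to torch.load. The "images" here are small tensors saved with torch.save, an assumption made only so the example runs without real image files.

```python
import csv
import os
import tempfile

import torch
from torch.utils.data import Dataset


class CsvTensorDataset(Dataset):
    """Hypothetical variant of CustomImageDataset using the csv module instead of pandas."""

    def __init__(self, annotations_file, data_dir, loader=torch.load,
                 transform=None, target_transform=None):
        # Read all (file_name, label) rows once; assumes no header row
        with open(annotations_file, newline="") as f:
            self.rows = [(name, int(label)) for name, label in csv.reader(f)]
        self.data_dir = data_dir
        self.loader = loader  # hypothetical hook: path -> tensor
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        name, label = self.rows[idx]
        sample = self.loader(os.path.join(self.data_dir, name))
        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            label = self.target_transform(label)
        return sample, label


# Usage: write two fake 1x2x2 "images" and an annotations file to a temp directory
tmp = tempfile.mkdtemp()
for i in range(2):
    torch.save(torch.full((1, 2, 2), float(i)), os.path.join(tmp, f"img{i}.pt"))
with open(os.path.join(tmp, "labels.csv"), "w", newline="") as f:
    csv.writer(f).writerows([["img0.pt", 3], ["img1.pt", 7]])

ds = CsvTensorDataset(os.path.join(tmp, "labels.csv"), tmp)
img, label = ds[1]
print(img.shape, label)  # torch.Size([1, 2, 2]) 7
```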

2.4 Preparing data for training

The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval. DataLoader is an iterable that handles all of this for us.
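A quick sketch of how batch_size and drop_last determine the number of batches, using a toy TensorDataset of 10 samples (an assumption for illustration). num_workers=0 loads data in the main process; setting it above 0 is what enables the multiprocessing mentioned above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))  # 10 samples: 0..9

# Default: the last, smaller batch is kept
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
print(len(loader))  # 3 batches (ceil(10 / 4)), of sizes 4, 4, 2

# drop_last=True discards the incomplete final batch
loader_drop = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
print(len(loader_drop))  # 2 batches (floor(10 / 4))
```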

from torch.utils.data import DataLoader

# test_data is the FashionMNIST test split (loaded with train=False)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

Once the dataset is loaded into the DataLoader, we can iterate through it as needed. Each iteration below returns a batch of train_features and train_labels (containing batch_size=64 features and labels, respectively). Because we specified shuffle=True, the data is reshuffled after we iterate over all batches.
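In a training loop, the DataLoader is usually iterated once per epoch. The sketch below uses a toy TensorDataset of 8 samples (an assumption for illustration) to check that, with shuffle=True, every epoch still visits each sample exactly once, only in a new order. The seeded torch.Generator passed through DataLoader's generator argument is there only to make the run reproducible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))
g = torch.Generator().manual_seed(0)  # seed for a reproducible shuffle order
loader = DataLoader(dataset, batch_size=3, shuffle=True, generator=g)

epoch_orders = []
for epoch in range(3):
    order = []
    # each batch is a one-element list (one entry per tensor in the TensorDataset)
    for (batch,) in loader:
        order.extend(batch.tolist())
    epoch_orders.append(order)
    print(f"epoch {epoch}: {order}")
```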

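import matplotlib.pyplot as plt

# Fetch one batch of images and labels
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")  # torch.Size([64, 1, 28, 28])
print(f"Labels batch shape: {train_labels.size()}")     # torch.Size([64])
# Display the first image of the batch and print its label
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")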

2.5 Transforms

Data doesn’t always come in the form required for training machine learning algorithms, so we use transforms to convert it into a more suitable form.

For example, every dataset in TorchVision has two parameters that accept the transformation logic: transform, which modifies the features, and target_transform, which modifies the labels. The torchvision.transforms module offers several commonly used transforms.
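Any callable can serve as this transformation logic, not only the classes in torchvision.transforms. As a minimal sketch, here is a hypothetical Standardize transform (our own class, not part of TorchVision) that shifts and scales a tensor; the mean and std values are illustrative only.

```python
import torch


class Standardize:
    """Hypothetical transform: (x - mean) / std, applied to a tensor."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, x):
        # A transform is just a callable that takes a sample and returns a new one
        return (x - self.mean) / self.std


t = Standardize(mean=0.5, std=0.25)
x = torch.tensor([0.0, 0.5, 1.0])
print(t(x))  # values -2, 0, 2
```

Because Standardize expects a tensor, with a TorchVision dataset it would have to run after ToTensor, for example transform=Compose([ToTensor(), Standardize(0.5, 0.25)]) using torchvision.transforms.Compose.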

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

# FashionMNIST features are PIL images and its labels are integers.
# For training, we need the features as normalized tensors and the labels as one-hot encoded tensors.
ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    # one-hot encode the integer label y as a length-10 float tensor
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor: converts a PIL image or NumPy ndarray into a FloatTensor and scales the image’s pixel intensity values into the range [0., 1.].
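To make the scaling concrete, this sketch reproduces by hand what ToTensor does to a uint8 image: reorder the axes from H × W × C to C × H × W and divide by 255. to_tensor_like is a hypothetical helper written with plain torch ops, not TorchVision's implementation, and it covers only the uint8 case (ToTensor rescales only uint8-style inputs and returns other inputs unscaled).

```python
import torch


def to_tensor_like(img_hwc: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mimicking ToTensor for a uint8 H x W x C image."""
    if img_hwc.dtype != torch.uint8:
        raise ValueError("this sketch only handles uint8 images")
    # H x W x C -> C x H x W, then scale [0, 255] -> [0.0, 1.0]
    return img_hwc.permute(2, 0, 1).float() / 255.0


# A 2 x 2 RGB "image" with gray pixel intensities 0, 51, 102, 255
img = torch.tensor([[[0, 0, 0], [51, 51, 51]],
                    [[102, 102, 102], [255, 255, 255]]], dtype=torch.uint8)
out = to_tensor_like(img)
print(out.shape)  # torch.Size([3, 2, 2])
print(out[0])     # values [[0.0, 0.2], [0.4, 1.0]]
```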

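Lambda Transforms: Lambda applies any user-defined lambda function. Here the function turns the integer label into a one-hot encoded tensor: it creates a zero tensor of size 10 (the number of classes in Fashion-MNIST) and calls scatter_, which assigns value=1 at the index given by the label y.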
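The one-hot target_transform used above can be checked on its own. The sketch below pulls the same scatter_ logic into a named function (one_hot_10 is our own hypothetical name, and it passes the index as a 1-element tensor) and compares the result with torch.nn.functional.one_hot, a built-in that produces the same encoding as an integer tensor.

```python
import torch
import torch.nn.functional as F


def one_hot_10(y: int) -> torch.Tensor:
    # zero vector of length 10, then write 1 at position y
    return torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor([y]), value=1)


print(one_hot_10(3))  # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])

# Built-in equivalent: returns int64, so cast to float to compare
builtin = F.one_hot(torch.tensor(3), num_classes=10).float()
print(torch.equal(one_hot_10(3), builtin))  # True
```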