📚 Using Custom Dataset in PyTorch
To decouple dataset code from model code, PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
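As a minimal sketch of how the two primitives fit together, the example below wraps a few in-memory tensors in a TensorDataset (a ready-made Dataset subclass from torch.utils.data) and iterates over it with a DataLoader. The toy tensors and variable names are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (illustrative only): 6 samples with 2 features each, plus integer labels.
features = torch.arange(12, dtype=torch.float32).reshape(6, 2)
labels = torch.tensor([0, 1, 0, 1, 0, 1])

# The Dataset stores samples and labels and supports len() and indexing.
toy_dataset = TensorDataset(features, labels)
first_x, first_y = toy_dataset[0]

# The DataLoader wraps the Dataset in an iterable that yields batches.
toy_loader = DataLoader(toy_dataset, batch_size=4, shuffle=False)
batch_shapes = [tuple(x.shape) for x, _ in toy_loader]

print(len(toy_dataset), first_x.tolist(), first_y.item())
print(batch_shapes)
```

With 6 samples and batch_size=4, the loader yields one full batch followed by a smaller final batch.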
Load a Dataset
Before we can work with data, we need some data. Fortunately, the PyTorch domain libraries provide a number of pre-loaded datasets, all of which are subclasses of Dataset. Here we use one of them, Fashion-MNIST, to show how to load a dataset.
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="data",           # the path where the data is stored
    train=True,            # True selects the training split, False the test split
    download=True,         # downloads the data from the internet if it is not available at root
    transform=ToTensor(),  # specifies the feature transformation
)
# Test split, used later when building the test DataLoader
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
2.2 Iterating and Visualizing the Dataset
We can index a Dataset manually like a list: training_data[index]. We can then use matplotlib to visualize the image data. To display a few random samples from Fashion-MNIST:
import matplotlib.pyplot as plt

figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    # pick a random sample index
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(training_data.classes[label])  # human-readable class name
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
2.3 Creating a Custom Dataset
Sometimes the prepared datasets don’t meet our needs, so we can create a custom dataset by subclassing Dataset. A custom Dataset class must implement three functions: __init__, __len__, and __getitem__.
- __init__: runs once when the Dataset object is instantiated. Here we store the annotations file, the directory containing the image files, and both transforms.
- __len__: returns the number of samples in the dataset.
- __getitem__: loads and returns the sample at the given index, applying the transforms to it if any are set.
A custom dataset for image files can then be written as follows:
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # column 0 of the annotations file holds the file name, column 1 the label
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)  # loads the image file as a tensor
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
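The same three-method pattern works for data that is not stored as image files. The sketch below is a hypothetical in-memory dataset (the class name PairsDataset and its toy data are assumptions for illustration, not part of PyTorch) that applies optional transform and target_transform callables in the same way as the class above.

```python
import torch
from torch.utils.data import Dataset


class PairsDataset(Dataset):
    """Hypothetical in-memory dataset: each sample is (feature tensor, int label)."""

    def __init__(self, features, labels, transform=None, target_transform=None):
        if len(features) != len(labels):
            raise ValueError("features and labels must have the same length")
        self.features = features
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # number of samples
        return len(self.labels)

    def __getitem__(self, idx):
        # fetch one sample and apply the optional transforms
        x = self.features[idx]
        y = self.labels[idx]
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            y = self.target_transform(y)
        return x, y


features = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
labels = [0, 1, 1]
pairs = PairsDataset(
    features,
    labels,
    transform=lambda x: x * 2,          # toy feature transform
    target_transform=lambda y: y + 10,  # toy label transform
)

x, y = pairs[1]
print(len(pairs), x.tolist(), y)
```

Because the class implements __len__ and __getitem__, it can be passed to a DataLoader like any built-in dataset.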
2.4 Preparing data for training
The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, however, we typically want to pass samples in “minibatches” and reshuffle the data at every epoch to reduce model overfitting. Sometimes we also want to use Python’s multiprocessing to speed up data loading. The DataLoader class handles all of the above for us.
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
After loading the dataset into the DataLoader, we can iterate through the data. Each iteration below returns a batch of train_features and train_labels, each containing batch_size entries. Because we set shuffle=True, the data is reshuffled after we iterate over all batches.
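To make the batching and reshuffling behaviour concrete, here is a small sketch on toy data (10 integer samples, an assumption chosen so the numbers are easy to follow). It illustrates that the last batch can be smaller than batch_size, and that with shuffle=True each pass over the DataLoader visits every sample exactly once, generally in a different order.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset (illustrative only): the integers 0..9, one per sample.
toy = TensorDataset(torch.arange(10))
loader = DataLoader(toy, batch_size=4, shuffle=True)

# 10 samples with batch_size=4 give batches of 4, 4 and 2 samples.
batch_sizes = [len(batch) for (batch,) in loader]


def one_epoch(dl):
    """Collect every sample value seen during one full pass over the loader."""
    return [int(v) for (batch,) in dl for v in batch]


epoch_1 = one_epoch(loader)
epoch_2 = one_epoch(loader)

# drop_last=True discards the incomplete final batch.
trimmed = DataLoader(toy, batch_size=4, shuffle=True, drop_last=True)
trimmed_sizes = [len(batch) for (batch,) in trimmed]

print(batch_sizes, trimmed_sizes)
print(epoch_1)
print(epoch_2)
```

The two printed epochs contain the same ten values; their order depends on the random shuffle.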
# get image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
2.5 Transforms
Data does not always come in the form required for training machine learning algorithms, so we use transforms to bring it into a more suitable form.
For example, the datasets in TorchVision have two parameters (transform, which modifies the features, and target_transform, which modifies the labels) that accept callables containing the transformation logic. Commonly used transforms are offered in the torchvision.transforms module.
# The FashionMNIST features are in PIL Image format and the labels are integers.
# For training we need the features as normalized tensors and the labels as one-hot encoded tensors.
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
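ToTensor: converts a PIL image or NumPy array into a Tensor and scales the image’s pixel intensity values to the range [0., 1.].
Lambda Transforms: apply any user-defined lambda function. Here, the function first creates a zero tensor of size 10 (the number of labels in the dataset) and then calls scatter_ to assign value=1 at the index given by the label y.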
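To see what the target_transform above produces, the sketch below applies the same scatter_ logic to a single label and compares the result with torch.nn.functional.one_hot, an alternative way to obtain the same encoding. The helper name to_one_hot and the label value 3 are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F


def to_one_hot(y, num_classes=10):
    """Same logic as the Lambda transform: zeros everywhere, 1 at index y."""
    return torch.zeros(num_classes, dtype=torch.float).scatter_(
        0, torch.tensor(y), value=1
    )


encoded = to_one_hot(3)

# Alternative: one_hot returns an integer tensor, so cast it to float to match.
alternative = F.one_hot(torch.tensor(3), num_classes=10).float()

print(encoded)
print(torch.equal(encoded, alternative))
```

Either form can be passed as target_transform; the scatter_ version is the one used in this section.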