How to Write a Custom Dataset and DataLoader
All of this code is available on my GitHub.
Datasets & DataLoaders
PyTorch’s abstract class representing a dataset is Dataset in torch.utils.data. Combined with PyTorch’s DataLoader, which batches and shuffles samples and can load them in parallel worker processes, Dataset allows us to handle our data efficiently.
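We’ll write a custom dataset class that inherits from Dataset, since ImageFolder doesn’t quite work for my purposes: ImageFolder expects one subdirectory per class (for example, root/cat/ and root/dog/) and takes each image’s label from its folder name, whereas here all the images sit in a single folder and the label is part of the file name. Writing a custom class also seems like a good skill to know, because data isn’t always organized as ImageFolder requires.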
Dataset Structure
We’ll use the Kaggle Cats and Dogs data, which has this structure:
```
root
├── train
│   ├── cat.0.jpg
│   ├── cat.1.jpg
│   ├── cat.2.jpg
│   ├── dog.0.jpg
│   └── dog.1.jpg
└── test1
    ├── 1.jpg
    ├── 2.jpg
    └── 3.jpg
```
You’ll have to modify my code to accommodate the structure of your data.
Sometimes, datasets come with an annotations file that lists their labels. Here there isn’t one: each training image’s file name contains its label (cat.0.jpg is a cat), while each test image’s file name is just its numeric id (1.jpg). Remembering this is important for our custom class.
Custom Dataset Class
Note that our dataset class should implement __len__(), which returns the length of the dataset, and must implement __getitem__(), which returns a data sample.
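To make the protocol concrete, here is a minimal sketch of a map-style dataset that has nothing to do with images: a hypothetical SquaresDataset whose i-th sample is the pair (i, i * i). Anything that subclasses Dataset and defines these two methods can be indexed with [] and handed to DataLoader.

```python
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples; DataLoader's default sampler relies on this
        return self.n

    def __getitem__(self, index):
        # Return exactly one sample for an integer index
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index * index


ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (3, 9)
```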
Dependencies
First, let’s import some dependencies.
```python
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from pathlib import Path
```
Constructor
Now, we can create our class and its constructor.
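```python
class CatsAndDogsDataset(Dataset):
    def __init__(self, target_directory, transform=None):
        # Training images are named label.idx.jpg, e.g. cat.0.jpg
        self.paths = list(Path(target_directory).glob('*.*.jpg'))
        # Optional callable applied to each image in __getitem__()
        self.transform = transform
        # Sorted list of unique labels, e.g. ['cat', 'dog']
        self.classes = sorted(list(set(map(self.get_label, self.paths))))
```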
We want to keep track of all the paths since they contain information about each image’s label and index. .glob() yields every file in the target directory whose name matches our label.idx.jpg pattern, and calling list() on that generator gives us a list of Path objects.
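To see what the '*.*.jpg' pattern picks up, the sketch below builds a throwaway directory mirroring the tree above (empty placeholder files, not real images) and globs it. Training-style names such as cat.0.jpg contain two dots and match; test-style names such as 1.jpg contain only one and don’t, so pointing this class at test1 would find no files. Also note that glob() yields paths in a filesystem-dependent order, so the sketch sorts the names to get a reproducible result.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Placeholder files named like the Kaggle data (contents don't matter here)
    for name in ["cat.0.jpg", "cat.1.jpg", "dog.0.jpg", "1.jpg", "notes.txt"]:
        (root / name).touch()

    # Same pattern as the constructor: <anything>.<anything>.jpg
    matched = sorted(p.name for p in root.glob("*.*.jpg"))

print(matched)  # ['cat.0.jpg', 'cat.1.jpg', 'dog.0.jpg']
```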
We also want to keep track of the image transforms we’ll do. More on this below.
The classes initialization looks complicated, but it’s not bad. map(function, iterable) returns a lazy iterator that applies the function to each item of the iterable. In our case, we want a list of all the classes in our dataset (cat and dog), so we call our get_label() method, defined below, on each path.
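Then, set(iterable) collects the results into a set, which by definition drops duplicate elements. Finally, sorted() turns the set into a sorted list (wrapping the set in list() first is harmless but redundant, since sorted() already returns a list). Sorting matters: sets have no guaranteed iteration order, and sorting ensures each class always gets the same index.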
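Here is the same map, set, sorted pipeline run on a handful of made-up paths, using a standalone copy of the get_label() logic so it runs without the rest of the class.

```python
from pathlib import Path


def get_label(path):
    # Standalone copy of the label logic: 'cat.0.jpg' -> 'cat'
    return path.name.split(".")[0]


paths = [Path("train/dog.1.jpg"), Path("train/cat.0.jpg"),
         Path("train/dog.0.jpg"), Path("train/cat.1.jpg")]

labels = map(get_label, paths)   # lazy iterator: 'dog', 'cat', 'dog', 'cat'
unique = set(labels)             # duplicates removed: {'cat', 'dog'}
classes = sorted(unique)         # sorted() already returns a list
print(classes)                   # ['cat', 'dog']
```

Whatever order the paths arrive in, classes comes out as ['cat', 'dog'], so cat is always index 0 and dog index 1.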
get_label()
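```python
    # (method of CatsAndDogsDataset)
    @staticmethod
    def get_label(path):
        # path.name drops the directory part: train/cat.0.jpg -> 'cat.0.jpg'
        filename = path.name
        # 'cat.0.jpg'.split('.') -> ['cat', '0', 'jpg']; the label comes first
        label = filename.split('.')[0]
        return label
```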
Let’s define a static method to get the label given a path object. (It takes a Path rather than a string because self.paths already holds Path objects.) The method is static because it doesn’t need any instance state: it’s useful to be able to get the label from a path without creating a dataset instance first.
Recall that our images have the format label.idx.jpg.
We use path.name to get the filename without the directory part (e.g. ./train). Then, we split the filename on '.', which returns a list such as ['cat', '0', 'jpg']. The first item of the list is the label, which we return.
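Because get_label() is a static method, it can be called on the class itself, with no dataset (and no directory of images) required. The sketch below uses a hypothetical stand-in class, LabelOnly, that contains only this method so it runs on its own.

```python
from pathlib import Path


class LabelOnly:
    """Hypothetical stand-in holding just the static get_label() method."""

    @staticmethod
    def get_label(path):
        filename = path.name             # 'cat.12.jpg' (directory dropped)
        return filename.split(".")[0]    # ['cat', '12', 'jpg'][0] -> 'cat'


# Called on the class: no instance needed
print(LabelOnly.get_label(Path("train/cat.12.jpg")))   # cat
# Also callable through an instance, like a normal method
print(LabelOnly().get_label(Path("train/dog.3.jpg")))  # dog
```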
load_image()
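```python
    # (method of CatsAndDogsDataset)
    def load_image(self, index):
        image_path = self.paths[index]
        # Image.open() is lazy; convert('RGB') loads the pixels and guarantees
        # three channels, matching the 3-channel Normalize() used later.
        # The with block closes the file once the converted copy exists.
        with Image.open(image_path) as img:
            return img.convert('RGB')
```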
We want to retrieve any image given an index. We find the corresponding path in our instance variable paths and return an Image object that represents that image. If we need the raw pixels, we can convert this Image object to a NumPy array with np.array(img) (after import numpy as np).
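A quick check of the Image-to-array conversion, using a small synthetic image written to a temporary file instead of a real photo (the file name is made up). One detail worth knowing: PIL reports size as (width, height), while the NumPy array is laid out as (height, width, channels).

```python
import tempfile
from pathlib import Path

import numpy as np
from PIL import Image

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "cat.0.jpg"             # hypothetical file name
    Image.new("RGB", (48, 32), color=(200, 120, 40)).save(path)

    with Image.open(path) as img:              # lazy: pixels load on first access
        print(img.size, img.mode)              # (48, 32) RGB
        arr = np.array(img)                    # loads and copies the pixels

print(arr.shape, arr.dtype)                    # (32, 48, 3) uint8
```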
__len__()
```python
    # (method of CatsAndDogsDataset)
    def __len__(self):
        # Number of samples = number of image files found
        return len(self.paths)
```
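This method is essentially required and self-explanatory: it returns the number of samples, which DataLoader’s default sampler uses to know which indices to draw.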
__getitem__()
```python
    # (method of CatsAndDogsDataset)
    def __getitem__(self, index):
        img = self.load_image(index)
        class_name = self.get_label(self.paths[index])
        # Map the string label to an integer: 'cat' -> 0, 'dog' -> 1
        class_idx = self.classes.index(class_name)
        if self.transform:
            return self.transform(img), class_idx
        else:
            return img, class_idx
```
We have to implement this method. First, we call our load_image() method from earlier. We’ll need to return the image, optionally transforming it beforehand. To find the correct label of this image, we call get_label() from earlier.
We call self.classes.index(class_name) to find the index of class_name. We need this because our model can’t work with the strings “cat” or “dog”; classification losses such as CrossEntropyLoss expect integer class indices. Because classes is sorted, cat maps to 0 and dog to 1.
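classes.index() scans the list on every call, which is fine for two classes. With many classes, a common alternative is to build a dictionary once (torchvision’s ImageFolder exposes a similar mapping as its class_to_idx attribute). A minimal sketch of both lookups:

```python
classes = ["cat", "dog"]

# Linear search, as in __getitem__()
print(classes.index("cat"), classes.index("dog"))  # 0 1

# Precomputed lookup, e.g. built once in __init__()
class_to_idx = {name: idx for idx, name in enumerate(classes)}
print(class_to_idx)            # {'cat': 0, 'dog': 1}
print(class_to_idx["dog"])     # 1
```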
Transforms in __getitem__()
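Before the last lines of __getitem__() make sense, we need to understand what a transform is. The images in the dataset may not have the same dimensions, and some may be distorted. If we want our model to learn from meaningful data, we should make our training data uniform. Uniform size is also a practical requirement: DataLoader stacks the samples of a batch into a single tensor, which only works when every image tensor has the same shape.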
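Transforms literally transform our image: a transform is a callable that takes an image and returns a modified version of it (cropped, converted to a tensor, normalized, and so on).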
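if self.transform: evaluates as true whenever a transform was passed in (that is, when self.transform isn’t None). If the client supplied a transform, we return the transformed image and the class index. If not, we simply return the image and class index. Keep in mind that the untransformed image is a PIL Image, which DataLoader’s default collate function can’t batch, so in practice the transform should at least include ToTensor().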
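Since a transform is just a callable, the optional-transform pattern can be exercised without images at all. In this sketch, a hypothetical plain-Python class, WithOptionalTransform, wraps a list of numbers and the transform is a lambda. It uses an explicit is not None check, which behaves the same as if self.transform: for ordinary callables but doesn’t depend on the callable’s truthiness.

```python
class WithOptionalTransform:
    """Hypothetical toy dataset showing the optional-transform pattern."""

    def __init__(self, items, transform=None):
        self.items = items
        self.transform = transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        sample = self.items[index]
        label = 0  # fixed dummy label for the sketch
        # Apply the transform only if the caller supplied one
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label


raw = WithOptionalTransform([1, 2, 3])
doubled = WithOptionalTransform([1, 2, 3], transform=lambda x: 2 * x)
print(raw[2], doubled[2])  # (3, 0) (6, 0)
```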
We’re done coding our dataset class! Now, let’s define our transforms, create our DataLoader, and test it on some images.
Transforms & DataLoader
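```python
train_transforms = transforms.Compose([
    transforms.CenterCrop((224, 224)),  # keep the central 224x224 region
    transforms.ToTensor(),              # PIL image -> float tensor (C, H, W) in [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # (x - 0.5) / 0.5 -> [-1, 1]
])
```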
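We crop the central 224 by 224 pixels of each image (I chose 224 because I’m feeding this into AlexNet). Note that CenterCrop doesn’t resize, so larger images lose their borders rather than being scaled down. Then, ToTensor() converts the Image object to a float tensor of shape (channels, height, width) with values in [0, 1], and Normalize() subtracts a mean of 0.5 and divides by a standard deviation of 0.5 in each channel, mapping values to [-1, 1]. Centering inputs around zero like this tends to help training. Normalize() only operates on tensors, which is why it comes after ToTensor().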
Let’s instantiate our dataset, too.
```python
train_dataset = CatsAndDogsDataset('./train', transform=train_transforms)
```
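And our DataLoader. I arbitrarily set batch_size=4. Shuffling (shuffle=True) reorders the samples every epoch, so batches don’t always contain the same images in the same order, which is generally what we want for training.
```python
train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
```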
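To see what batch_size=4 and shuffle=True do without real images, here is a sketch using torch’s TensorDataset with 10 fake samples (random 3 by 8 by 8 tensors standing in for images; the sizes are arbitrary). Ten samples in batches of four give three batches of sizes 4, 4 and 2, because the last partial batch is kept unless drop_last=True. Shuffling changes which samples land in which batch, but each sample still appears exactly once per epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 fake "images" (3x8x8) with alternating labels 0, 1, 0, 1, ...
images = torch.randn(10, 3, 8, 8)
labels = torch.arange(10) % 2
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=True)
print(len(loader))  # 3, i.e. ceil(10 / 4)

batch_sizes = []
for batch_images, batch_labels in loader:
    # Default collation stacks samples along a new first dimension
    print(batch_images.shape, batch_labels.shape)
    batch_sizes.append(batch_images.shape[0])
print(batch_sizes)  # [4, 4, 2]

# Dropping the incomplete final batch instead
print(len(DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)))  # 2
```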
Testing Dataset and DataLoader
Let’s see if we did everything right by plotting a batch of images and printing their labels.
Look at our cute, properly labeled cats and dogs!