
How to Write a Custom Dataset and DataLoader

This code is all available in my GitHub.

Datasets & DataLoaders

PyTorch’s abstract class for representing a dataset is Dataset, found in torch.utils.data. Combined with PyTorch’s DataLoader, a Dataset lets us load and batch our data efficiently.

We’ll write a custom dataset class that inherits from Dataset, since ImageFolder doesn’t quite work for my purposes. It’s also a good skill to have because data isn’t always organized the way ImageFolder requires.

Dataset Structure

We’ll use the Kaggle Cats and Dogs data, which has this structure:

root
├── train
│   ├── cat.0.jpg
│   ├── cat.1.jpg
│   ├── cat.2.jpg
│   ├── dog.0.jpg
│   └── dog.1.jpg
└── test1
    ├── 1.jpg
    ├── 2.jpg
    └── 3.jpg

You’ll have to modify my code to accommodate the structure of your data.

Sometimes, datasets have an annotations file for their labels. In this case, each training image’s file name contains its label. The test images’ file names are their IDs. Remembering this is important for our custom class.

Custom Dataset Class

Note that our dataset class should implement __len__(), which returns the length of the dataset, and must implement __getitem__(), which returns a data sample.
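As a rough illustration of that contract, here is a toy class that implements just those two methods. SquaresDataset is a hypothetical name, not part of this post’s code, and it deliberately skips inheriting from Dataset so that it runs without PyTorch installed. A real dataset would subclass torch.utils.data.Dataset.

```python
class SquaresDataset:
    """Toy map-style dataset (hypothetical): sample i is the pair (i, i * i)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # Number of samples in the dataset
        return self.size

    def __getitem__(self, index):
        # Return one (input, target) sample for the given index
        if not 0 <= index < self.size:
            raise IndexError(index)
        return index, index * index


toy = SquaresDataset(5)
toy_length = len(toy)   # calls __len__
toy_sample = toy[3]     # calls __getitem__
print(toy_length, toy_sample)
```

Anything that supports len() and integer indexing like this can be treated as a map-style dataset.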

Dependencies

First, let’s import some dependencies.

import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from pathlib import Path

Constructor

Now, we can create our class and a constructor.

class CatsAndDogsDataset(Dataset):

    def __init__(self, target_directory, transform=None):
        # Collect every file named like label.idx.jpg
        self.paths = list(Path(target_directory).glob('*.*.jpg'))
        self.transform = transform
        # Sorted list of unique labels, e.g. ['cat', 'dog']
        self.classes = sorted(list(set(map(self.get_label, self.paths))))

We want to keep track of all the paths since they contain information about the label and index. .glob() lets us select all files in the target directory that match our required naming pattern. It returns a generator, so we call list() on the result to get a list of paths.
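To see what the pattern picks up, here is a small sketch that builds a temporary folder and globs it. The file names are made up for the demonstration and only the standard library is used.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    # Hypothetical file names, created only for this demonstration
    for name in ["cat.0.jpg", "dog.0.jpg", "1.jpg", "notes.txt"]:
        (folder / name).touch()

    # '*.*.jpg' needs two dots in the name, i.e. the label.idx.jpg shape
    matched = sorted(p.name for p in folder.glob("*.*.jpg"))

print(matched)
```

Only the two label.idx.jpg files match. A test-style name like 1.jpg has a single dot, so this pattern skips it.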

We also want to keep track of the image transforms we’ll do. More on this below.

The classes initialization looks complicated, but it’s not bad. map(function, iterable) returns a map object that applies the function to each item in the iterable. In our case, we want to have a list of all the classes in our dataset (dog and cat), so we’ll call our get_label() method defined below for each item.

Then, we call set(iterable), which turns our map object into a set and, by definition, removes duplicate elements. Finally, we call list() and sorted(), which turn the set into a list and then sort it.
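Here is the same chain unrolled one step at a time, as a sketch on a few made-up paths. The standalone get_label() function below is a stand-in for the static method defined in the next section.

```python
from pathlib import Path


def get_label(path):
    # Stand-in for the dataset's static method: 'cat.0.jpg' -> 'cat'
    return path.name.split(".")[0]


# Hypothetical paths, used only for illustration
paths = [Path("train/dog.0.jpg"), Path("train/cat.0.jpg"), Path("train/cat.1.jpg")]

labels = map(get_label, paths)   # lazy iterator over 'dog', 'cat', 'cat'
unique_labels = set(labels)      # duplicates removed; order not guaranteed
classes = sorted(unique_labels)  # sorted() accepts any iterable and returns a list

print(classes)
```

Since sorted() already returns a list, the inner list() call in the constructor is optional. It doesn’t hurt to keep it.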

get_label()

    @staticmethod
    def get_label(path):
        # path.name is already a string, e.g. 'cat.0.jpg'
        filename = path.name
        label = filename.split('.')[0]
        return label

Let’s define a static method to get the label given a path object. (We work with Path objects rather than strings because our instance variable is a list of paths.) The method is static because getting a label shouldn’t require an instance of the class: given any path, we can call get_label() without creating a dataset.

Recall that our images have the format label.idx.jpg.

We call path.name to get our filename without the directory information (e.g., ./train). Then, we split on the . character, which returns a list. The first item of the list is the label, which we return.
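The steps above can be traced on a single made-up path. LabelDemo is a hypothetical stand-in class, used here only to show that a static method can be called without creating an instance.

```python
from pathlib import Path


class LabelDemo:
    # Hypothetical stand-in for the dataset class
    @staticmethod
    def get_label(path):
        return path.name.split(".")[0]


sample_path = Path("./train/cat.12.jpg")  # made-up example path

name = sample_path.name                    # directory part dropped
parts = name.split(".")                    # label, index, extension
label = LabelDemo.get_label(sample_path)   # called on the class, no instance

print(name, parts, label)
```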

load_image()

    def load_image(self, index):
        image_path = self.paths[index]
        return Image.open(image_path)

We want to retrieve any image given an index. We find the corresponding path in our instance variable paths and return an Image object that represents that image. If needed, we can convert this Image object to a NumPy array using np.array(img), provided NumPy is imported as np.

__len__()

    def __len__(self):
        return len(self.paths)

This method is required and self-explanatory.

__getitem__()

    def __getitem__(self, index):
        img = self.load_image(index)
        class_name = self.get_label(self.paths[index])
        class_idx = self.classes.index(class_name)

        if self.transform:
            return self.transform(img), class_idx
        else:
            return img, class_idx

We have to implement this method. First, we call our load_image() method from earlier to get the image. We’ll return that image, optionally transforming it beforehand.

To find the correct label of this image, we call get_label() from earlier.

We call self.classes.index(class_name) to find the index of class_name. This is because our model can’t interpret the string “cat” or “dog.” Rather, we want a numerical representation, in this case 0 or 1.
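A quick sketch of that lookup on the two classes in this dataset. The class_to_idx dictionary is an optional alternative I’m adding for illustration, not something the class above defines. It avoids the linear scan that list.index() performs, which only starts to matter with many classes.

```python
classes = ["cat", "dog"]  # what self.classes holds for this dataset

cat_idx = classes.index("cat")  # position of 'cat' in the sorted list
dog_idx = classes.index("dog")  # position of 'dog' in the sorted list

# Hypothetical alternative: build the mapping once, then look up by key
class_to_idx = {name: i for i, name in enumerate(classes)}

print(cat_idx, dog_idx, class_to_idx)
```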

Transforms in __getitem__()

Before the last lines of __getitem__() make sense, we need to understand what a transform is. The images in the dataset may not all have the same dimensions. They may be distorted. If we want our model to learn from meaningful data, we should ensure our training data is uniform.

Transforms literally transform our image, for example by cropping it or converting it to a tensor.

The condition if self.transform: holds when self.transform isn’t None. If there is a transform the client wants to apply, we’ll return the transformed image and the class index.

If not, simply return the image and class index.


We’re done coding our dataset class! Now, let’s define our transforms, create our DataLoader, and test it on some images.

Transforms & DataLoader

train_transforms = transforms.Compose([
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

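We center-crop each image to 224 by 224 pixels (I chose 224 because I’m feeding this into AlexNet). Then ToTensor() converts the Image object to a tensor with values in [0, 1], and Normalize() with a mean and standard deviation of 0.5 per channel rescales those values to [-1, 1], which tends to help training. Normalize() only operates on tensors, so it has to come after ToTensor().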
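To check the range these Normalize() arguments produce, here is the per-channel formula, output = (input - mean) / std, worked in plain Python. The normalize() helper is a hypothetical stand-in for the arithmetic, not the torchvision call. It assumes the input is already in [0, 1], which is what ToTensor() produces for ordinary 8-bit images.

```python
def normalize(value, mean=0.5, std=0.5):
    # Same arithmetic Normalize applies to each channel value
    return (value - mean) / std


darkest = normalize(0.0)    # smallest value ToTensor can produce
middle = normalize(0.5)     # mid-gray
brightest = normalize(1.0)  # largest value ToTensor can produce

print(darkest, middle, brightest)
```

So with a mean and standard deviation of 0.5, pixel values end up in [-1, 1], centered on zero.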

train_dataset = CatsAndDogsDataset('./train', transform=train_transforms)

Let’s instantiate our dataset, too.

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)

And our dataloader. I arbitrarily set batch_size=4. Setting shuffle=True reshuffles the data every epoch.
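To get a feel for what the dataloader hands back, here is a sketch that doesn’t need the image files. RandomImages is a hypothetical stand-in dataset that returns random tensors shaped like our transformed images, so the shapes match but the content is noise. It assumes PyTorch is installed.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomImages(Dataset):
    """Hypothetical stand-in: 10 random 3x224x224 'images' with 0/1 labels."""

    def __len__(self):
        return 10

    def __getitem__(self, index):
        return torch.rand(3, 224, 224), index % 2


loader = DataLoader(dataset=RandomImages(), batch_size=4, shuffle=True)

images, labels = next(iter(loader))  # one batch
batch_shape = tuple(images.shape)    # (batch, channels, height, width)
label_shape = tuple(labels.shape)    # one class index per image
num_batches = len(loader)            # 10 samples in batches of 4: 4 + 4 + 2

print(batch_shape, label_shape, num_batches)
```

The dataloader stacks the individual samples into one tensor per batch and collects the integer labels into a tensor, too. The last batch is smaller because drop_last defaults to False.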

Testing Dataset and DataLoader

Let’s see if we did everything right. Plotting a batch of our images and printing their labels, we get:

[Figure: a batch of training images plotted with their labels]

Look at our cute, properly labeled cats and dogs!

Michel Liao

Boise, Idaho, United States
Hello! I'm a sophomore studying computer science at Princeton. I like reading, rock climbing, and running.