Custom Dataset, DataLoader - Paxoo/PyTorch-Best_Practices GitHub Wiki

PyTorch brings the ability to easily craft a custom Dataset object which can then be used with the built-in DataLoader to feed data when training a model.

Table of Contents

  1. Custom Dataset Fundamentals
  2. Using Torchvision Transforms
  3. Using DataLoader
  4. Best Practice

Custom Dataset Fundamentals

A custom dataset must implement the following methods to be usable by a DataLoader later on:

  • __init__(): the initial logic happens here (assigning transforms, filtering data, etc.).
  • __getitem__(): returns one sample of data and its label.
  • __len__(): returns the number of samples your dataset has.

```python
from torch.utils.data import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, transform=None):
        # initial logic: build the sample list, assign transforms,
        # filter data, etc.
        self.transform = transform

    def __getitem__(self, index):
        # load and return one sample; __getitem__() is not restricted
        # to returning (img, label) -- any structure works, e.g. a dict
        return (img, label)

    def __len__(self):
        # return the number of samples (e.g. images) the dataset has
        return count
```
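To make the skeleton concrete, here is a runnable sketch of a hypothetical dataset that serves (feature, label) pairs from in-memory tensors (the class name and data are made up for illustration):

```python
import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Hypothetical dataset: sample i is the pair (i, i*i)."""
    def __init__(self, n_samples):
        # __init__(): initial logic, here just building the data
        self.x = torch.arange(n_samples, dtype=torch.float32)
        self.y = self.x ** 2

    def __getitem__(self, index):
        # return one (data, label) pair
        return self.x[index], self.y[index]

    def __len__(self):
        # number of samples in the dataset
        return len(self.x)

ds = SquaresDataset(5)
print(len(ds))   # 5
print(ds[3])     # (tensor(3.), tensor(9.))
```

Because the class implements `__getitem__` and `__len__`, it can be passed to `DataLoader` directly.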

Using Torchvision Transforms

Transforms are common image transformations. They can be chained together using Compose. All transforms accept a PIL Image, a Tensor Image, or a batch of Tensor Images as input. Deterministic or random transforms applied to a batch of Tensor Images transform all images of the batch identically. List of all transformations: Link

Example

```python
import torch
from torchvision import transforms

train_transforms = transforms.Compose([
        transforms.RandomCrop(size=32),
        # careful: RandomRotation(degrees=90) rotates by a random angle
        # sampled from [-90°, +90°], not by exactly 90°
        transforms.RandomRotation(degrees=90),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

# pass the pipeline to your own dataset (torchvision has no generic
# CustomDataset class; use the one defined above)
train_dataset = MyCustomDataset(root=PathToImages,
                                transform=train_transforms)
```
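Normalize with mean = std = 0.5 maps ToTensor's [0, 1] range to [-1, 1]; a quick torch-only check of that arithmetic:

```python
import torch

# ToTensor() yields values in [0, 1]; Normalize computes (x - mean) / std
x = torch.tensor([0.0, 0.5, 1.0])
y = (x - 0.5) / 0.5
print(y)  # tensor([-1., 0., 1.])
```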

Use the following code to fix the seed and make the random transforms reproducible.

```python
import torch
torch.manual_seed(17)
```
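This works because the random transforms draw from torch's global RNG; reseeding replays the same sequence of random numbers, as a minimal demonstration shows:

```python
import torch

torch.manual_seed(17)
a = torch.rand(3)           # first draw after seeding
torch.manual_seed(17)
b = torch.rand(3)           # same seed -> identical draw
print(torch.equal(a, b))    # True
```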

Using DataLoader

While the Dataset class is a nice way of containing data systematically, in a training loop we would still have to index or slice the dataset's samples ourselves, which is no better than what we would do with a plain list or NumPy array. Rather than going down that route, PyTorch supplies another utility, the DataLoader, which acts as a data feeder for a Dataset object and handles batching, shuffling, and parallel loading.

  • Setting num_workers > 0 enables asynchronous data loading in background worker processes, overlapping data loading with training. num_workers should be tuned depending on the workload, CPU, GPU, and location of the training data.
  • DataLoader accepts a pin_memory argument, which defaults to False. When using a GPU, set pin_memory=True: this instructs DataLoader to place batches in pinned (page-locked) memory, which enables faster host-to-GPU copies (and asynchronous ones when combined with tensor.to(device, non_blocking=True)).
  • prefetch_factor (default 2, only valid when num_workers > 0) controls how many batches each worker loads in advance. Prefetching only starts once the iterator is created, i.e. when the for sample in dataloader: loop begins.
  • persistent_workers=True keeps the worker processes (and their Dataset instances) alive after the dataset has been consumed once, instead of shutting them down and respawning them each epoch. This saves per-epoch worker startup cost at the price of keeping the processes and their memory around.

Example:

```python
from torch.utils.data import DataLoader, RandomSampler

train_data_loader = DataLoader(dataset=train_dataset,
                               batch_size=args.batchsize,
                               shuffle=False,  # mutually exclusive with sampler
                               sampler=RandomSampler(train_dataset),
                               num_workers=args.num_workers,
                               pin_memory=True,
                               prefetch_factor=2,
                               drop_last=True)
```
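A self-contained variant of the example above, using torch's built-in TensorDataset and dummy data instead of a custom dataset and an args object:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler

images = torch.randn(100, 3, 32, 32)          # 100 dummy RGB "images"
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset=dataset,
                    batch_size=16,
                    sampler=RandomSampler(dataset),  # random order
                    num_workers=0,   # >0 would spawn worker processes
                    drop_last=True)  # drop the incomplete last batch

batches = list(loader)
print(len(batches))          # 100 // 16 = 6 full batches
print(batches[0][0].shape)   # torch.Size([16, 3, 32, 32])
```

Note that drop_last=True discards the remainder of 4 samples, so every batch has a uniform shape.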

Best Practice

```python
import random

import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

class EBiMADataset(Dataset):
    def __init__(self, root_dir, mode='train', rotation=False, transform=None):
        self.blur_list = []
        self.sharp_list = []

        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.rotation = rotation

        self.scan()

    def scan(self):
        # scan root_dir for image paths and fill
        # self.blur_list and self.sharp_list
        ...

    def __len__(self):
        return len(self.blur_list)

    def __getitem__(self, idx):
        blur_image = Image.open(self.blur_list[idx]).convert('RGB')
        sharp_image = None
        if len(self.sharp_list) > 0:
            sharp_image = Image.open(self.sharp_list[idx]).convert('RGB')

        if self.rotation:
            # draw the angle once so both images stay aligned
            degree = random.choice([0, 90, 180, 270])
            blur_image = transforms.functional.rotate(blur_image, degree)
            if sharp_image is not None:
                sharp_image = transforms.functional.rotate(sharp_image, degree)

        if self.transform:
            # careful: random transforms in self.transform (RandomCrop,
            # RandomHorizontalFlip) draw new parameters on every call,
            # so blur and sharp are cropped/flipped independently here
            blur_image = self.transform(blur_image)
            if sharp_image is not None:
                sharp_image = self.transform(sharp_image)

        return {'blur_image': blur_image, 'sharp_image': sharp_image}

def main():
    train_dataset = EBiMADataset(
        root_dir='Y:\\Institute\\IMFAA\\Messdaten\\05_EBiMA_Aufnahmen\\',
        mode='train',
        rotation=True,  # own rotation function; add custom transformations in the custom Dataset
        transform=transforms.Compose([
            transforms.RandomCrop(size=256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ]))

    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batchsize,
                                  shuffle=False,  # mutually exclusive with sampler
                                  sampler=RandomSampler(train_dataset),
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  prefetch_factor=2,
                                  drop_last=True)

    for epoch in range(start_epoch, args.epochs):
        if args.training:
            model.train()
            tq = tqdm(train_dataloader, ncols=80, smoothing=0, bar_format='{desc}|{bar}{r_bar}')
            for iteration, data in enumerate(tq):
                # the dataset returns a dict, so index by key
                blur = data['blur_image'].cuda()
                sharp = data['sharp_image'].cuda()
                ...
```
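One caveat of the example above: passing the same Compose of random transforms to the blur and sharp images applies them with independently drawn parameters, so the pair can end up with different crops or flips. A torch-only sketch of sampling the parameters once and applying them to both tensors (paired_random_crop is a hypothetical helper, not part of the original code):

```python
import random
import torch

def paired_random_crop(a, b, size):
    # draw the crop coordinates once, then apply the identical
    # crop to both (C, H, W) tensors so the pair stays aligned
    _, h, w = a.shape
    top = random.randint(0, h - size)
    left = random.randint(0, w - size)
    return (a[:, top:top + size, left:left + size],
            b[:, top:top + size, left:left + size])

blur = torch.randn(3, 64, 64)
sharp = blur.clone()  # identical input to verify alignment
blur_c, sharp_c = paired_random_crop(blur, sharp, size=32)
print(blur_c.shape)                    # torch.Size([3, 32, 32])
print(torch.equal(blur_c, sharp_c))    # True -> same crop window
```

The same idea (sample parameters once, apply via the functional API) is what the rotation branch in __getitem__ already does for the angle.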