Custom Dataset, DataLoader - Paxoo/PyTorch-Best_Practices GitHub Wiki
PyTorch brings the ability to easily craft a custom Dataset object which can then be used with the built-in DataLoader to feed data when training a model.
Table of Contents
- Custom Dataset Fundamentals
- Using Torchvision Transforms
- Using DataLoader
- Best Practice
Custom Dataset Fundamentals
A dataset must implement the following methods to be usable by a DataLoader later on:
- `__init__()` — the initial logic happens here: assigning transforms, filtering data, etc.
- `__getitem__()` — returns one sample and its label.
- `__len__()` — returns the number of samples in the dataset.
from torch.utils.data.dataset import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff

    def __getitem__(self, index):
        # stuff
        # __getitem__() is not restricted to returning (img, label);
        # it can return whatever structure your training loop expects
        return (img, label)

    def __len__(self):
        return count  # number of samples (e.g. images) in the dataset
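As a minimal runnable sketch of this skeleton (the tensor data here is made up purely for illustration), a dataset wrapping in-memory tensors looks like this:

```python
import torch
from torch.utils.data import Dataset

class TensorPairDataset(Dataset):
    """Minimal Dataset over in-memory tensors (illustrative data only)."""
    def __init__(self, n_samples=8):
        # fake "images" (3x32x32) and integer labels, standing in for real data
        self.images = torch.randn(n_samples, 3, 32, 32)
        self.labels = torch.randint(0, 10, (n_samples,))

    def __getitem__(self, index):
        # return one (img, label) pair
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.images)

dataset = TensorPairDataset()
img, label = dataset[0]
print(len(dataset))   # 8
print(img.shape)      # torch.Size([3, 32, 32])
```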
Using Torchvision Transforms
Transforms are common image transformations. They can be chained together using Compose. All transformations accept a PIL Image, a Tensor image, or a batch of Tensor images as input. A deterministic or random transformation applied to a batch of Tensor images transforms all images of the batch identically. See the torchvision transforms documentation for the full list.
Example
import torch
from torchvision import datasets, transforms

train_transforms = transforms.Compose([
    transforms.RandomCrop(size=32),
    transforms.RandomRotation(degrees=90),  # careful: rotates by a random angle sampled from [-90, 90] degrees
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

# CustomDataset is a placeholder for your own Dataset class
# (torchvision itself does not provide datasets.CustomDataset)
train_dataset = datasets.CustomDataset(root=PathToImages,
                                       train=True,
                                       transform=train_transforms)
Use the following to fix the random seed so that random transforms become reproducible:
import torch
torch.manual_seed(17)
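For example, re-seeding makes random draws (and hence the random transforms driven by them) repeat exactly:

```python
import torch

torch.manual_seed(17)
a = torch.rand(3)        # first draw after seeding

torch.manual_seed(17)    # same seed -> same sequence of random numbers
b = torch.rand(3)

print(torch.equal(a, b))  # True
```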
Using DataLoader
While the Dataset class is a nice way of containing data systematically, in a training loop we would still need to index or slice the dataset's samples ourselves, which is no better than what we would do with a plain list or NumPy array. Rather than going down that route, PyTorch supplies another utility, the DataLoader, which acts as a data feeder for a Dataset object.
- Setting `num_workers > 0` enables asynchronous data loading and overlap between training and data loading. `num_workers` should be tuned depending on the workload, CPU, GPU, and location of the training data.
- DataLoader accepts a `pin_memory` argument, which defaults to False. When using a GPU it is better to set `pin_memory=True`: this instructs DataLoader to use pinned memory and enables faster, asynchronous memory copies from the host to the GPU.
- `prefetch_factor` controls how much data each worker loads in advance (default = 2; it requires `num_workers > 0`). Prefetching only starts at the beginning of the epoch, when the iterator is created — i.e. when you do `for sample in dataloader:`.
- Setting `persistent_workers=True` keeps the workers' Dataset instances alive: the data loader will not shut down the worker processes after the dataset has been consumed once, which avoids paying the worker startup cost on every epoch.
Example:
from torch.utils.data import DataLoader, RandomSampler

train_data_loader = DataLoader(dataset=train_dataset,
                               batch_size=args.batchsize,
                               shuffle=False,  # the sampler already randomizes the order
                               sampler=RandomSampler(train_dataset),
                               num_workers=args.num_workers,
                               pin_memory=True,
                               prefetch_factor=2,
                               drop_last=True)
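Putting the pieces together, here is a runnable sketch with a toy in-memory dataset standing in for `train_dataset` (`num_workers=0` and no pinned memory, so it runs without a GPU; the dataset and its contents are made up for illustration):

```python
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler

class ToyDataset(Dataset):
    """10 fake one-dimensional samples, for illustration only."""
    def __init__(self):
        self.data = torch.arange(10, dtype=torch.float32).unsqueeze(1)

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

ds = ToyDataset()
loader = DataLoader(dataset=ds,
                    batch_size=4,
                    shuffle=False,            # RandomSampler already shuffles
                    sampler=RandomSampler(ds),
                    num_workers=0,            # >0 only pays off with real I/O
                    drop_last=True)           # drop the incomplete last batch

batches = [b for b in loader]
print(len(batches))      # 2 (10 samples / batch_size 4, last 2 dropped)
print(batches[0].shape)  # torch.Size([4, 1])
```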
Best Practice
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torchvision import transforms, datasets
from PIL import Image
from tqdm import tqdm
import random
class EBiMADataset(Dataset):
    def __init__(self, root_dir, mode='train', rotation=False, transform=None):
        self.blur_list = []
        self.sharp_list = []
        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.rotation = rotation
        self.scan()

    def scan(self):
        # scan root_dir for images and fill self.blur_list / self.sharp_list
        ...

    def __len__(self):
        return len(self.blur_list)

    def __getitem__(self, idx):
        blur_image = Image.open(self.blur_list[idx]).convert('RGB')
        if len(self.sharp_list) == 0:
            # no paired sharp image available (e.g. inference on blurred data only)
            if self.transform:
                blur_image = self.transform(blur_image)
            return {'blur_image': blur_image}
        sharp_image = Image.open(self.sharp_list[idx]).convert('RGB')
        if self.rotation:
            # draw the angle once so both images get the identical rotation
            degree = random.choice([0, 90, 180, 270])
            blur_image = transforms.functional.rotate(blur_image, degree)
            sharp_image = transforms.functional.rotate(sharp_image, degree)
        if self.transform:
            # note: random transforms (e.g. RandomCrop) are sampled independently
            # for the two calls below, so the pair will not stay pixel-aligned
            blur_image = self.transform(blur_image)
            sharp_image = self.transform(sharp_image)
        return {'blur_image': blur_image, 'sharp_image': sharp_image}
def main():
    train_dataset = EBiMADataset(
        root_dir='Y:\\Institute\\IMFAA\\Messdaten\\05_EBiMA_Aufnahmen\\',
        mode='train',
        rotation=True,  # own rotation function; custom transformations live inside the Dataset
        transform=transforms.Compose([
            transforms.RandomCrop(size=256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ]))
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batchsize,
                                  shuffle=False,
                                  sampler=RandomSampler(train_dataset),
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  prefetch_factor=2,
                                  drop_last=True)
    for epoch in range(start_epoch, args.epochs):
        if args.training:
            model.train()
            tq = tqdm(train_dataloader, ncols=80, smoothing=0, bar_format='{desc}|{bar}{r_bar}')
            for iteration, data in enumerate(tq):
                # the dataset returns a dict, so unpack by key
                blur = data['blur_image'].cuda()
                sharp = data['sharp_image'].cuda()
                ...
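The rotation trick above — drawing the random parameter once and applying it to both images — generalizes to other paired augmentations. A sketch of a paired random crop using plain tensor slicing (the function name and data here are illustrative, not part of the wiki's code):

```python
import random
import torch

def paired_random_crop(img_a, img_b, size):
    """Crop two (C, H, W) tensors at the SAME random location,
    so a blur/sharp pair stays pixel-aligned."""
    _, h, w = img_a.shape
    top = random.randint(0, h - size)    # one random location...
    left = random.randint(0, w - size)
    crop_a = img_a[:, top:top + size, left:left + size]   # ...used for both crops
    crop_b = img_b[:, top:top + size, left:left + size]
    return crop_a, crop_b

blur = torch.randn(3, 64, 64)
sharp = blur.clone()  # identical content, so aligned crops must match exactly
a, b = paired_random_crop(blur, sharp, size=32)
print(a.shape)            # torch.Size([3, 32, 32])
print(torch.equal(a, b))  # True
```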