Pytorch에서 데이터를 업로드하는 세가지 방법 - DonghoonPark12/ssd.pytorch GitHub Wiki

Pytorch에서 데이터를 업로드하는 방법에는 세가지가 있다.

  1. 아래 코드 처럼 Dataset을 상속받는 새로운 custom 데이터셋 클래스를 정의하거나
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor 

class VOCDataset(Dataset) # 추상 클래스
   def __init__(root, transform)

train_dataset = VOCDataset(root='../../data',
                           train=True,
                           transform=ToTensor())
  1. torchvision.datasets 에 학습하고자 하는 데이터셋이 있는 경우 아래 코드 처럼 바로 가져다 쓴다.
from torchvision.datasets import COCO
from torchvision.transforms import ToTensor 

train_dataset = COCO(root='../../data',
                     train=True,
                     transform=ToTensor(),
                     download=False)
  1. 혹은 이미지가 클래스 폴더별 png형식으로 저장되에 있는 경우(예시) 아래의 방식을 따른다.
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor 

train_dir = '../../data/train'
train_dataset = ImageFolder(train_dir, transform=ToTensor())

# 혹은 고급 스럽게 다음과 같이 할 수도 있다.

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: ImageFolder(os.path.join(data_dir, x), data_transforms[x])    for x in ['train', 'val']}

하지만 위의 3가지 방식에서 공통적으로 등장한 train_dataset객체는 모두 torch.utils.data.Dataset의 서브 클래스로 부터 생성된 점은 동일하며 아래의 DataLoader에 모두 인수로 들어갈 수 있다.

from torch.utils.data import DataLoader
data_loader = DataLoader(root=train_dataset, 
                         batch_size=32, 
                         shuffle=True, 
                         num_workers=args.nTreads)

# 혹은 고급스럽게 다음과 같이 할 수도 있다.

dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,  shuffle=True, num_workers=4)    for x in ['train', 'val']}

[Refer]
https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder
https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html -끝-