切割資料集 - jenhaoyang/ml_blog GitHub Wiki

# coding: utf-8
import torch
import torch.utils.data as data
from torchvision import datasets

# Random split
def split_dataset(dataset, train_ratio=0.7, valid_ratio=0.2, seed=42):
    """Randomly split *dataset* into train / validation / test subsets.

    The train and validation sizes are ``int(len(dataset) * ratio)``, i.e.
    truncated toward zero. The test subset receives every remaining sample,
    so the three sizes always add up to ``len(dataset)``.

    Args:
        dataset: A sized dataset accepted by ``torch.utils.data.random_split``.
        train_ratio: Fraction of samples assigned to the training subset.
        valid_ratio: Fraction of samples assigned to the validation subset.
        seed: Seed for the ``torch.Generator`` that drives the permutation,
            so the same dataset and seed always give the same split.

    Returns:
        A ``(train_set, valid_set, test_set)`` tuple of subsets.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more
            than 1.
    """
    # Without this check a negative test length would pass random_split's
    # "lengths sum to len(dataset)" check and yield wrongly-sized subsets.
    if train_ratio < 0 or valid_ratio < 0 or train_ratio + valid_ratio > 1:
        raise ValueError(
            "train_ratio and valid_ratio must be non-negative and sum to at "
            f"most 1, got {train_ratio} and {valid_ratio}"
        )
    total = len(dataset)
    train_set_size = int(total * train_ratio)
    valid_set_size = int(total * valid_ratio)
    # The remainder goes to the test split so no sample is dropped.
    test_set_size = total - train_set_size - valid_set_size
    generator = torch.Generator().manual_seed(seed)
    train_set, valid_set, test_set = data.random_split(
        dataset,
        [train_set_size, valid_set_size, test_set_size],
        generator=generator,
    )
    return train_set, valid_set, test_set


def print_split_sizes(dataset, train_set, valid_set, test_set):
    """Print the size of the full dataset followed by the size of each split."""
    print('='*30)
    print('Total data set:', len(dataset))
    print('Train data set:', len(train_set))
    print('Valid data set:', len(valid_set))
    print('Test data set:', len(test_set))


if __name__ == "__main__":
    # `dataset` is not created in this snippet: define it first (for example
    # a torchvision dataset from the `datasets` import) before running this.
    train_set, valid_set, test_set = split_dataset(dataset)

    # After
    print_split_sizes(dataset, train_set, valid_set, test_set)

參考:
https://clay-atlas.com/blog/2020/09/27/pytorch-cn-random-split-dataset/#comment-2248