マルチタスク学習のサンプル - you1025/my_something_flagments GitHub Wiki
マルチタスク学習のサンプルコードを残しておく。
モデル定義と予測〜ロスの算出のとこがポイント。
MNIST データを使った通常の 10 クラス(0〜9)識別に丸みのある数字(0, 6, 8, 9)を識別するタスクを加えたマルチタスクを試してみる。
※丸みのある数字
に特に意味は無くあくまで実験用(書籍のサンプルを読んだ時も意味不明だったw)
参考
- PyTorchによる物体検出 の 1.7 複雑なネットワークの学習
サンプルコード
モデル定義
feature_map + fc_10
と feature_map + fc_2
の 2 系統の出力が必要なため forward は不要。
import torch
class CNN(torch.nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 特徴マップの生成層
self.feature_map = torch.nn.Sequential(
torch.nn.Conv2d(1, 20, 5),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2),
torch.nn.Conv2d(20, 50, 5),
torch.nn.ReLU(),
torch.nn.Dropout(0.5),
torch.nn.Flatten()
)
# 10 カテゴリ用
self.fc_10 = torch.nn.Linear(3200, 10)
# 2 カテゴリ用
self.fc_2 = torch.nn.Linear(3200, 2)
学習コード
ポイントは特徴マップを取得した後にタスク毎の出力でロスを算出し、2 つのロスを足し合わせて 1 つのロスとする部分。
loss.backward()
でタスク毎の全結合層にロスが流れて学習を実行する。
import numpy as np
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 変換の定義
transform = transforms.Compose([
transforms.ToTensor()
])
# 学習・テスト用のデータセットを取得
train_set = datasets.MNIST("./mnistdata", train=True, transform=transform, download=True)
test_set = datasets.MNIST("./mnistdata", train=False, transform=transform, download=True)
# ミニバッチ用の DataLoader を作成
batch_size = 256
train_loader = DataLoader(train_set, batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size, shuffle=False, num_workers=2)
model = CNN().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
# 学習の実行
for epoch in range(10):
# ミニバッチ毎に学習を実施
for (inputs, labels) in train_loader:
labels_10 = labels
# apply_ は元オブジェクトを変更するので clone が必要(labels オブジェクトが変更されてしまうのを避ける)
labels_2 = labels.clone().apply_(lambda label: 1 if label in [6, 8, 9] else 0)
# 特徴マップを取得
feature_map = model.feature_map(inputs.to(device))
# タスク毎に出力を取得
outputs_10 = model.fc_10(feature_map)
outputs_2 = model.fc_2(feature_map)
# タスク毎にロスを計算
loss_10 = criterion(outputs_10, labels_10.to(device))
loss_2 = criterion(outputs_2, labels_2.to(device))
# ロスをまとめる
loss = loss_10 + loss_2
optimizer.zero_grad()
loss.backward()
optimizer.step()
まとめ
- ミニバッチ中に 2 カテゴリ用の教師ラベルを作成しているところは Dataset をちゃんと作成する事で対応した方が良いかと
- 特徴マップの生成層へ流れ込むロスが大きくなるはずなので学習の進行が変わる事の影響についてはちょっと疑問
- ロスの加算時に重みを考慮するやり方もある模様 (深層学習入門:画像分類(4)マルチタスク学習)