複数モデルを結合したモデルから、単体モデルを取得する - Shinichi0713/LLM-fundamental-study GitHub Wiki
結論
state dictのプレフィックスが、目的のもののみを以下コードのように集める
# 2. encoderのみのパラメータを抽出
encoder_state_dict = {k.replace("encoder.", ""): v for k, v in state_dict.items() if k.startswith("encoder.")}
その後、目的のモデルでロードする。
# 3. 新しいencoderモデルを作り、パラメータをロード
new_encoder = Encoder()
new_encoder.load_state_dict(encoder_state_dict)
以下に「encoderモデルとdownstreamモデルを結合したモデルのstate_dict」から、「encoderモデルのみのパラメータ」を抽出してencoderモデルにロードするサンプルコードを示します。
サンプルコード
import torch
import torch.nn as nn
# サンプルのEncoderとDownstreamモデル
class Encoder(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20)
def forward(self, x):
return self.fc1(x)
class Downstream(nn.Module):
def __init__(self):
super().__init__()
self.fc2 = nn.Linear(20, 5)
def forward(self, x):
return self.fc2(x)
# 結合モデル
class FullModel(nn.Module):
def __init__(self, encoder, downstream):
super().__init__()
self.encoder = encoder
self.downstream = downstream
def forward(self, x):
x = self.encoder(x)
x = self.downstream(x)
return x
# --- モデル作成・保存 ---
encoder = Encoder()
downstream = Downstream()
full_model = FullModel(encoder, downstream)
# 例: state_dictの保存
torch.save(full_model.state_dict(), "full_model.pth")
# --- encoderのみのパラメータ抽出 ---
# 1. state_dictのロード
state_dict = torch.load("full_model.pth")
# 2. encoderのみのパラメータを抽出
encoder_state_dict = {k.replace("encoder.", ""): v for k, v in state_dict.items() if k.startswith("encoder.")}
# 3. 新しいencoderモデルを作り、パラメータをロード
new_encoder = Encoder()
new_encoder.load_state_dict(encoder_state_dict)
# 動作確認
print(new_encoder)
ポイント
state_dict
のキーは、encoder.fc1.weight
のようにencoder.
プレフィックスがついています。- そのため、
k.startswith("encoder.")
で抽出し、replace("encoder.", "")
でプレフィックスを除去します。 - こうしてできた
encoder_state_dict
を、Encoderモデルにload_state_dict()
します。