複数モデルを結合したモデルから、単体モデルを取得する - Shinichi0713/LLM-fundamental-study GitHub Wiki

結論

state dictのプレフィックスが、目的のもののみを以下コードのように集める

# 2. encoderのみのパラメータを抽出
encoder_state_dict = {k.replace("encoder.", ""): v for k, v in state_dict.items() if k.startswith("encoder.")}

その後、目的のモデルでロードする。

# 3. 新しいencoderモデルを作り、パラメータをロード
new_encoder = Encoder()
new_encoder.load_state_dict(encoder_state_dict)

以下に「encoderモデルとdownstreamモデルを結合したモデルのstate_dict」から、「encoderモデルのみのパラメータ」を抽出してencoderモデルにロードするサンプルコードを示します。


サンプルコード

import torch
import torch.nn as nn

# サンプルのEncoderとDownstreamモデル
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
    def forward(self, x):
        return self.fc1(x)

class Downstream(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc2 = nn.Linear(20, 5)
    def forward(self, x):
        return self.fc2(x)

# 結合モデル
class FullModel(nn.Module):
    def __init__(self, encoder, downstream):
        super().__init__()
        self.encoder = encoder
        self.downstream = downstream
    def forward(self, x):
        x = self.encoder(x)
        x = self.downstream(x)
        return x

# --- モデル作成・保存 ---
encoder = Encoder()
downstream = Downstream()
full_model = FullModel(encoder, downstream)

# 例: state_dictの保存
torch.save(full_model.state_dict(), "full_model.pth")

# --- encoderのみのパラメータ抽出 ---
# 1. state_dictのロード
state_dict = torch.load("full_model.pth")

# 2. encoderのみのパラメータを抽出
encoder_state_dict = {k.replace("encoder.", ""): v for k, v in state_dict.items() if k.startswith("encoder.")}

# 3. 新しいencoderモデルを作り、パラメータをロード
new_encoder = Encoder()
new_encoder.load_state_dict(encoder_state_dict)

# 動作確認
print(new_encoder)

ポイント

  • state_dictのキーは、encoder.fc1.weightのようにencoder.プレフィックスがついています。
  • そのため、k.startswith("encoder.")で抽出し、replace("encoder.", "")でプレフィックスを除去します。
  • こうしてできたencoder_state_dictを、Encoderモデルにload_state_dict()します。