모델 Class 코드 내 헷갈렸던 init, super, 상속, optimizer 개념 적립 - Songwooseok123/Study_Space GitHub Wiki

  • 항상 헷갈리고 정확히 몰랐던 개념들
  • 강다솔 저자의 책 “한 권으로 끝내는 실전 LLM 파인튜닝”에 있는 gpt 학습 코드 활용
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size = 32
block_size = 8
max_iteration = 50000
eval_interval = 300
learning_rate = 1e-2
device = "cuda" if torch.cuda.is_available() else "cpu"
eval_iteration = 200

def batch_function(mode):
    dataset = train_dataset if mode == "train" else test_dataset
    idx = torch.randint(len(dataset) - block_size, (batch_size,))
    x = torch.stack([dataset[index:index+block_size] for index in idx])
    y = torch.stack([dataset[index+1:index+block_size+1] for index in idx])
    x, y = x.to(device), y.to(device) # .to 를 추가
    return x, y

@torch.no_grad()
def compute_loss_metrics():
    out = {}
    model.eval()
    for mode in ["train", "eval"]:
        losses = torch.zeros(eval_iteration)
        for k in range(eval_iteration):
            inputs, targets = batch_function(mode)
            logits, loss = model(inputs, targets)
            losses[k] = loss.item()
        out[mode] = losses.mean()
    model.train()
    return out

class semiGPT(nn.Module):
    def __init__(self, vocab_length):
        super().__init__()
        self.embedding_token_table = nn.Embedding(vocab_length, vocab_length)

    def forward(self, inputs, targets=None):
        logits = self.embedding_token_table(inputs)
        if targets is None:
            loss = None
        else:
            batch, seq_length, vocab_length = logits.shape
            logits = logits.view(batch * seq_length, vocab_length)
            targets = targets.view(batch*seq_length)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, inputs, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self.forward(inputs)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_inputs = torch.multinomial(probs, num_samples=1)
            inputs = torch.cat((inputs, next_inputs), dim=1)
        return inputs

model = semiGPT(ko_vocab_size).to(device) # 여기서 model이 instance임 ("semiGPT 클래스로부터 만들어진 model 인스턴스"라고 표현함) 
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for step in range(max_iteration):
    if step % eval_interval == 0 :
        losses = compute_loss_metrics()
        print(f'step : {step}, train loss : {losses["train"]:.4f}, val loss : {losses["eval"]:.4f}')

    example_x, example_y = batch_function("train")
    logits, loss = model(example_x, example_y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

inputs = torch.zeros((1,1), dtype=torch.long, device=device)
print(token_decode(model.generate(inputs, max_new_tokens=100)[0].tolist()))

Init 메서드

  • 이닛 메서드는 클래스 인스턴스가 생성될 때, 자동으로 호출되는 메서드다. 이를 파이썬 생성자라고 부른다. class 인스턴스를 만들 때 기본설정하는 것과 같다.
    • semiGPT는 init 함수에서 vocab_length를 변수로 받는 classs임.
  • 따라서 모델을 정의할 때 다음과 같이 정의함
    • model = semiGPT(ko_vocab_size).to(device)
    • model_2 = semiGPT(ko_vocab_size).to(device)
    • model과 model_2는 init을 설정해줬기 대문에 서로 독립적으로 작동할 수 있다.
    • super 부분은 부모 클래스의 init 메서드를 호출하는 것이다.
    • super는 물려받은 부모 클래스의 메서드에 접근할 수 있게 해줌
    • super().init() 호출은 nn.Module의 생성자를 호출한다. 이는 nn.Module의 모든 기능과 속성을 제대로 초기화하고 사용할 수 있게 하는 단계임.

forward 메서드

  • forward 메서드는 실제로 데이터가 모델을 통과하는 과정
    • logits, loss = model(example_x, example_y)
class semiGPT(nn.Module): 
    def __init__(self, vocab_length):
        super().__init__()
        self.embedding_token_table = nn.Embedding(vocab_length, vocab_length)

    def forward(self, inputs, targets=None):
        logits = self.embedding_token_table(inputs)
        if targets is None:
            loss = None
        else:
            batch, seq_length, vocab_length = logits.shape
            logits = logits.view(batch * seq_length, vocab_length)
            targets = targets.view(batch*seq_length)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

optimizer

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
  • 첫번째 변수는 최적화할 모델의 매개변수
    • lora 학습시 이 부분을 수정했던 것 같음.