Transformer Code Breakdown block by block - beyondnlp/nlp GitHub Wiki
source : https://www.k-a.in/transformers.html
- embedding class
import torch
import torch.nn as nn
import math
class InputEmbeddings(nn.Module):
    # The constructor takes the model dimension and the vocabulary size as parameters.
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        # Store the model dimension and vocabulary size as class attributes.
        self.d_model = d_model
        self.vocab_size = vocab_size
        # Create an embedding layer that maps token indices to d_model-dimensional vectors.
        self.embedding = nn.Embedding(vocab_size, d_model)

    # Forward pass that processes the input x.
    def forward(self, x):
        # Apply the embedding and multiply by sqrt(d_model) to keep the variance stable, as described in the paper.
        return self.embedding(x) * math.sqrt(self.d_model)
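A minimal usage sketch for the embedding block above; the sizes (d_model=512, vocab_size=1000, a batch of 2 sequences of length 10) are arbitrary values chosen for illustration, assuming the InputEmbeddings class is defined as shown.
import torch
emb = InputEmbeddings(d_model=512, vocab_size=1000)
tokens = torch.randint(0, 1000, (2, 10))   # (batch, seq_len) of token indices
out = emb(tokens)                          # embedding lookup scaled by sqrt(d_model)
print(out.shape)                           # torch.Size([2, 10, 512])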
- positional encoding class
# Class that adds positional information to the embeddings.
class PositionalEncoding(nn.Module):
    # The constructor takes the model dimension, the maximum sequence length, and the dropout rate as parameters.
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        # Store the dimensions as class attributes.
        self.d_model = d_model
        self.seq_len = seq_len
        # Create the dropout layer.
        self.dropout = nn.Dropout(dropout)
        # Initialize the positional encoding tensor with zeros.
        pe = torch.zeros(seq_len, d_model)
        # Create a column vector with the position indices from 0 to seq_len-1.
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # Create the division term for the sinusoids according to the formula in the paper.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Apply sine to the even indices of the positional encoding.
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to the odd indices of the positional encoding.
        pe[:, 1::2] = torch.cos(position * div_term)
        # Add a batch dimension.
        pe = pe.unsqueeze(0)
        # Register the positional encoding tensor as a buffer (persistent state that is not a parameter).
        self.register_buffer('pe', pe)

    # Forward pass for the input x.
    def forward(self, x):
        # Add the positional encoding to the input, sliced to the input sequence length, with gradients disabled.
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        # Apply dropout to the sum of the input and the positional encoding.
        return self.dropout(x)
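A quick sketch of how the two blocks compose; the sizes are illustrative only, assuming the InputEmbeddings and PositionalEncoding classes above are in scope. seq_len is the maximum length the encoding table supports, so the actual input only has to be shorter.
import torch
emb = InputEmbeddings(d_model=512, vocab_size=1000)
pos = PositionalEncoding(d_model=512, seq_len=64, dropout=0.1)
tokens = torch.randint(0, 1000, (2, 10))   # sequence length 10 <= seq_len 64
x = pos(emb(tokens))                       # embeddings plus positional information, then dropout
print(x.shape)                             # torch.Size([2, 10, 512])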
- multi head attention block class
# Implements the multi-head attention mechanism.
class MultiHeadAttentionBlock(nn.Module):
    # The constructor takes the model dimension, the number of heads, and the dropout rate.
    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        # Store the dimensions as attributes.
        self.d_model = d_model
        self.h = h
        # Make sure the model dimension is divisible by the number of heads.
        assert d_model % h == 0, "d_model is not divisible by h"
        self.d_k = d_model // h
        # Linear projections for queries, keys, values, and the output, all without bias.
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    # Static method that implements scaled dot-product attention.
    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        # Get the dimension of the keys/queries.
        d_k = query.shape[-1]
        # Compute the attention scores with a matrix multiplication and scaling.
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        # Apply the mask by setting masked positions to a very large negative value (effectively negative infinity).
        if mask is not None:
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1)
        # Apply dropout to the attention scores.
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask):
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)
        # Reshape and transpose the tensors for multi-head processing.
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
        # Compute the attention and store the scores.
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        # Reshape the output back to its original size.
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        return self.w_o(x)
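A usage sketch for self-attention with the block above; the shapes and the all-ones mask are illustrative assumptions, not part of the original code.
import torch
mha = MultiHeadAttentionBlock(d_model=512, h=8, dropout=0.1)
x = torch.randn(2, 10, 512)             # (batch, seq_len, d_model)
mask = torch.ones(1, 1, 10, 10)         # broadcasts over (batch, h, seq, seq); nonzero means "attend"
out = mha(x, x, x, mask)                # self-attention: q, k, v come from the same tensor
print(out.shape)                        # torch.Size([2, 10, 512])
print(mha.attention_scores.shape)       # torch.Size([2, 8, 10, 10]), saved in forward()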
- layer norm class
# Implements layer normalization to stabilize the network activations.
class LayerNormalization(nn.Module):
    # The constructor takes the number of features and a small epsilon used to avoid division by zero.
    def __init__(self, features: int, eps: float = 10**-6) -> None:
        super().__init__()
        self.eps = eps
        # Learnable scaling parameter, initialized to 1.
        self.alpha = nn.Parameter(torch.ones(features))
        # Learnable bias parameter, initialized to 0.
        self.bias = nn.Parameter(torch.zeros(features))

    # Forward pass for the input x.
    def forward(self, x):
        # Compute the mean over the feature dimension.
        mean = x.mean(dim=-1, keepdim=True)
        # Compute the standard deviation over the feature dimension.
        std = x.std(dim=-1, keepdim=True)
        # Normalize the input, then apply the scaling and the bias.
        return self.alpha * (x - mean) / (std + self.eps) + self.bias
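A small check of the normalization behaviour, assuming the class above; since alpha starts at 1 and bias at 0, each feature vector initially comes out with mean near 0 and standard deviation near 1.
import torch
ln = LayerNormalization(features=512)
x = torch.randn(2, 10, 512)
y = ln(x)
print(y.mean(dim=-1).abs().max())   # close to 0
print(y.std(dim=-1).mean())         # close to 1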
- feedforward block
class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        # First linear transformation, from d_model to d_ff dimensions.
        self.linear_1 = nn.Linear(d_model, d_ff)
        # Dropout layer for regularization.
        self.dropout = nn.Dropout(dropout)
        # Second linear transformation, back from d_ff to d_model dimensions.
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # Apply the first linear transformation, ReLU activation, dropout, and the second linear transformation.
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
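A shape sketch for the feed-forward block; d_ff=2048 follows the paper's default, the rest is illustrative.
import torch
ffn = FeedForwardBlock(d_model=512, d_ff=2048, dropout=0.1)
x = torch.randn(2, 10, 512)
print(ffn(x).shape)                 # torch.Size([2, 10, 512]); d_ff only appears in the hidden layer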
- residual connection class
class ResidualConnection(nn.Module):
    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        # Apply layer normalization, the sublayer, dropout, and the residual connection.
        return x + self.dropout(sublayer(self.norm(x)))
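A sketch of wrapping the feed-forward block in a residual connection, assuming both classes above are defined; this is the pre-norm pattern used throughout this implementation.
import torch
res = ResidualConnection(features=512, dropout=0.1)
ffn = FeedForwardBlock(d_model=512, d_ff=2048, dropout=0.1)
x = torch.randn(2, 10, 512)
out = res(x, ffn)                   # x + dropout(ffn(norm(x)))
print(out.shape)                    # torch.Size([2, 10, 512])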
- encoder block
class EncoderBlock(nn.Module):
    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # Two residual connections: one around self-attention, one around the feed-forward block.
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        # Self-attention with the source mask, wrapped in the first residual connection.
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        # Feed-forward block wrapped in the second residual connection.
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x
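A sketch wiring one encoder block end to end, assuming the classes above; the sizes and the all-ones source mask are illustrative.
import torch
attn = MultiHeadAttentionBlock(d_model=512, h=8, dropout=0.1)
ffn = FeedForwardBlock(d_model=512, d_ff=2048, dropout=0.1)
block = EncoderBlock(features=512, self_attention_block=attn, feed_forward_block=ffn, dropout=0.1)
x = torch.randn(2, 10, 512)
src_mask = torch.ones(1, 1, 10, 10)
print(block(x, src_mask).shape)     # torch.Size([2, 10, 512])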