class TransformerBlock(nn.Module):
"""
A straightforward transformer block.
"""
    def __init__(self, emb, heads, mask, seq_length, ff_hidden_mult=4, dropout=0.0, attention_type='default',
                 pos_embedding=None, sa_kwargs=None):
super().__init__()
if attention_type == 'default':
            self.attention = SelfAttention(emb, heads=heads, mask=mask, **(sa_kwargs or {}))
elif attention_type == 'alt':
self.attention = SelfAttentionAlt(emb, heads=heads, mask=mask)
elif attention_type == 'wide':
self.attention = SelfAttentionWide(emb, heads=heads, mask=mask)
elif attention_type == 'gpt2':
self.attention = SelfAttentionGPT2(emb, heads=heads, mask=mask)
elif attention_type == 'narrow':
self.attention = SelfAttentionNarrow(emb, heads=heads, mask=mask)
elif attention_type == 'relative':
assert pos_embedding is not None
self.attention = SelfAttentionRelative(emb, heads=heads, mask=mask, pos_embedding=pos_embedding)
else:
            raise ValueError(f'Self-attention type {attention_type} not recognized.')
self.mask = mask
self.norm1 = nn.LayerNorm(emb)
self.norm2 = nn.LayerNorm(emb)
self.ff = nn.Sequential(
nn.Linear(emb, ff_hidden_mult * emb),
nn.ReLU(),
nn.Linear(ff_hidden_mult * emb, emb)
)
self.do = nn.Dropout(dropout)

    def forward(self, x):
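        # self-attention sub-layer: attend, add the residual, then normalize and apply dropout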
attended = self.attention(x)
x = self.norm1(attended + x)
x = self.do(x)
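        # feed-forward sub-layer: transform, add the residual, then normalize and apply dropout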
fedforward = self.ff(x)
x = self.norm2(fedforward + x)
x = self.do(x)
return x
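

# Minimal usage sketch (illustrative only, not part of the module): it assumes the
# attention classes referenced above (SelfAttention, etc.) are defined in this file
# and that torch is installed. Shapes are (batch, sequence length, embedding dim).
if __name__ == '__main__':
    import torch

    block = TransformerBlock(emb=128, heads=8, mask=False, seq_length=64)
    x = torch.randn(4, 64, 128)   # (batch, time, emb)
    out = block(x)                # shape is preserved: (4, 64, 128)
    print(out.size())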