src(来自gpt‐fast的repo) - Nanji-Huaji/hierarchical-speculative-decoding GitHub Wiki
以下内容由AI生成
本项目提供了一个基于 Transformer 架构的高效文本生成工具,支持多种解码策略,包括自回归采样、推测解码和张量并行。它专为高性能生成任务设计,适用于各种自然语言处理应用场景。
def device_sync(device):
if "cuda" in device:
torch.cuda.synchronize(device)
elif ("cpu" in device) or ("mps" in device):
pass
else:
print(f"device={device} is not yet suppported")
功能: 根据设备类型进行同步操作。
参数:
-
device
(str
): 设备类型(如"cuda"
、"cpu"
)。
class Transformer:
# 模型定义和方法
...
功能: Transformer 模型类,负责文本生成的核心计算。
主要方法:
-
setup_caches
: 设置缓存。 -
forward
: 模型前向传播。
from tokenizer import get_tokenizer
功能: 获取分词器。
参数:
-
tokenizer_path
: 分词器路径。 -
checkpoint_path
: 模型检查点路径。
返回值:
- 分词器对象。
def multinomial_sample_one_no_sync(probs_sort):
q = torch.empty_like(probs_sort).exponential_(1)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
功能: 进行多项式采样,无需 CUDA 同步。
参数:
-
probs_sort
: 概率张量。
返回值:
- 采样的 token 索引。
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
logits = logits / max(temperature, 1e-5)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
功能: 将 logits 转换为概率分布,支持温度控制和 Top-K 过滤。
参数:
-
logits
: 输入的 logits 张量。 -
temperature
: 温度参数。 -
top_k
: Top-K 采样参数。
返回值:
- 转换后的概率张量。
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
probs = logits_to_probs(logits[:, -1], temperature, top_k)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
功能: 从 logits 中采样下一个 token。
参数:
-
logits
: 输入的 logits 张量。 -
temperature
: 温度参数。 -
top_k
: Top-K 采样参数。
返回值:
- 采样的 token 索引和概率。
def roundup(val, multiplier):
return ((val - 1) // multiplier + 1) * multiplier
功能: 将值向上取整到最近的倍数。
参数:
-
val
: 输入值。 -
multiplier
: 倍数。
返回值:
- 向上取整后的值。
def causal_mask(b, h, q, kv):
return q >= kv
功能: 创建因果掩码。
参数:
-
b
: 批量大小。 -
h
: 注意力头数。 -
q
: 查询位置。 -
kv
: 键值位置。
返回值:
- 因果掩码。
def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device=x.device)
logits = model(mask, x, input_pos)
return sample(logits, **sampling_kwargs)[0]
功能: 预填充模型,生成初始 token。
参数:
-
model
: Transformer 模型。 -
x
: 输入张量。 -
input_pos
: 输入位置。 -
sampling_kwargs
: 采样参数。
返回值:
- 采样的 token。
def decode_one_token(
model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, block_mask: BlockMask, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
block_index = input_pos // block_mask.BLOCK_SIZE[0]
mask = block_mask[:, :, block_index]
mask.mask_mod = block_mask.mask_mod
mask.seq_lengths = (1, model.max_seq_length)
logits = model(mask, x, input_pos)
return sample(logits, **sampling_kwargs)
功能: 解码一个 token。
参数:
-
model
: Transformer 模型。 -
x
: 输入张量。 -
input_pos
: 输入位置。 -
block_mask
: 块掩码。 -
sampling_kwargs
: 采样参数。
返回值:
- 采样的 token 和概率。
def decode_n_tokens(
model: Transformer,
cur_token: torch.Tensor,
input_pos: torch.Tensor,
num_new_tokens: int,
callback=lambda _: _,
**sampling_kwargs,
):
block_mask = create_block_mask(
causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device=cur_token.device
)
new_tokens, new_probs = [], []
for i in range(num_new_tokens):
next_token, next_prob = decode_one_token(model, cur_token, input_pos, block_mask, **sampling_kwargs)
input_pos += 1
new_tokens.append(next_token.clone())
callback(new_tokens[-1])
new_probs.append(next_prob.clone())
cur_token = next_token.clone()
return new_tokens, new_probs
功能: 解码多个 token。
参数:
-
model
: Transformer 模型。 -
cur_token
: 当前 token。 -
input_pos
: 输入位置。 -
num_new_tokens
: 需要生成的 token 数量。 -
callback
: 回调函数。 -
sampling_kwargs
: 采样参数。
返回值:
- 生成的 token 列表和概率列表。
def model_forward(model, x, input_pos):
return model(x, input_pos)
功能: 执行模型的前向传播。
参数:
-
model
: Transformer 模型。 -
x
: 输入张量。 -
input_pos
: 输入位置。
返回值:
- 模型的输出。
def speculative_decode(
model: Transformer,
draft_model: Transformer,
cur_token: torch.Tensor,
input_pos: int,
speculate_k: int,
**sampling_kwargs,
) -> torch.Tensor:
# draft model inference sequentially
device = cur_token.device
orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
draft_tokens, draft_probs = decode_n_tokens(
draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs
)
draft_tokens = torch.cat(draft_tokens)
# parallel inference on target model using draft tokens
target_logits = model_forward(
model,
torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device),
)
target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
draft_probs = torch.stack(draft_probs)
# q: target prob, p: draft prob
# q >= p: always accept draft token
# q < p: q/p prob to accept draft token
p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k] / p)
rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()
if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
accept_length = speculate_k + 1
last_token = multinomial_sample_one_no_sync(target_probs[-1])
# fill last token into draft model
model_forward(
draft_model,
draft_tokens[-1].view(1, -1),
orig_input_pos + speculate_k,
)
return torch.cat([draft_tokens, last_token])
else:
accept_length = rejected_locations[0].item()
p = draft_probs[accept_length]
q = target_probs[accept_length]
new = q - p
new = torch.where(new > 0, new, 0.0)
new = new / new.sum()
next_token = multinomial_sample_one_no_sync(new)
return torch.cat([draft_tokens[:accept_length], next_token])
功能: 执行推测解码。
参数:
-
model
: Transformer 模型。 -
draft_model
: 草稿模型。 -
cur_token
: 当前 token。 -
input_pos
: 输入位置。 -
speculate_k
: 推测深度。 -
sampling_kwargs
: 采样参数。
返回值:
- 生成的 token。
@torch.no_grad()
def generate(
model: Transformer,
prompt: torch.Tensor,
max_new_tokens: int,
batch_size: int,
*,
interactive: bool,
draft_model: Transformer,
speculate_k: Optional[int] = 8,
callback=lambda x: x,
**sampling_kwargs,
) -> torch.Tensor:
# 主生成函数的实现
...
功能: 主生成函数,负责协调模型加载、文本编码和生成过程。
参数:
-
model
: Transformer 模型。 -
prompt
: 输入提示。 -
max_new_tokens
: 最大生成的 token 数量。 -
batch_size
: 批量大小。 -
interactive
: 是否启用交互模式。 -
draft_model
: 草稿模型。 -
speculate_k
: 推测深度。 -
callback
: 回调函数。 -
sampling_kwargs
: 采样参数。
返回值:
- 生成的序列和统计信息。
def encode_tokens(tokenizer, string, bos=True, device=default_device):
tokens = tokenizer.encode(string)
if bos:
tokens = [tokenizer.bos_id()] + tokens
return torch.tensor(tokens, dtype=torch.int, device=device)
功能: 将字符串编码为 token IDs。
参数:
-
tokenizer
: 分词器。 -
string
: 输入字符串。 -
bos
: 是否添加开头标记。 -
device
: 目标设备。
返回值:
- 编码后的 token IDs。
def _load_model(checkpoint_path, device, precision, use_tp):
# 加载模型的实现
...
功能: 从检查点加载模型。
参数:
-
checkpoint_path
: 模型检查点路径。 -
device
: 目标设备。 -
precision
: 模型精度。 -
use_tp
: 是否使用张量并行。
返回值:
- 加载的模型。
def _get_model_size(model):
model_size = 0
params = 0
for name, child in model.named_children():
if not isinstance(child, torch.nn.Embedding):
model_size += sum(
[p.numel() * p.dtype.itemsize for p in itertools.chain(child.parameters(), child.buffers())]
)
params += sum([p.numel() for p in itertools.chain(child.parameters(), child.buffers())])
return model_size, params
功能: 获取模型大小和参数数量。
参数:
-
model
: 模型。
返回值:
- 模型大小和参数数量。
def main(
prompt: Union[int, str] = "Hello, my name is",
interactive: bool = False,
num_samples: int = 5,
max_new_tokens: int = 100,
batch_size: int = 1,
top_k: int = 200,
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
compile: bool = True,
compile_prefill: bool = False,
profile: Optional[Path] = None,
draft_checkpoint_path: Optional[Path] = None,
speculate_k: int = 5,
device=default_device,
) -> None:
# 主函数的实现
...
功能: 主函数,负责初始化和运行文本生成过程。
参数:
-
prompt
: 输入提示。 -
interactive
: 是否启用交互模式。 -
num_samples
: 生成的样本数量。 -
max_new_tokens
: 最大生成的 token 数量。 -
batch_size
: 批量大小。 -
top_k
: Top-K 采样参数。 -
temperature
: 温度参数。 -
checkpoint_path
: 模型检查点路径。 -
compile
: 是否编译模型。 -
compile_prefill
: 是否编译预填充阶段。 -
profile
: 性能分析文件路径。 -
draft_checkpoint_path
: 草稿模型检查点路径。 -
speculate_k
: 推测深度。 -
device
: 目标设备。
逻辑:
- 初始化分布式训练环境。
- 加载模型和分词器。
- 编码输入提示。
- 根据配置进行文本生成。
- 输出生成结果和性能指标。
# 加载模型和分词器
checkpoint_path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth")
model = _load_model(checkpoint_path, device="cuda", precision=torch.bfloat16, use_tp=False)
tokenizer = get_tokenizer(checkpoint_path.parent / "tokenizer.model", checkpoint_path)
# 编码输入提示
prompt = "Hello, my name is"
encoded = encode_tokens(tokenizer, prompt, bos=True, device="cuda")
# 生成文本
generated_text, metrics = generate(
model,
encoded,
max_new_tokens=100,
batch_size=1,
temperature=0.8,
top_k=200,
)
print(tokenizer.decode(generated_text[0].tolist()))
# 加载草稿模型和目标模型
draft_model = _load_model(draft_checkpoint_path, device="cuda", precision=torch.bfloat16, use_tp=False)
target_model = _load_model(checkpoint_path, device="cuda", precision=torch.bfloat16, use_tp=False)
# 推测解码
speculative_decode(target_model, draft_model, cur_token, input_pos, speculate_k=5)
- 高效解码: 通过缓存机制和编译优化,提高生成效率。
- 灵活配置: 支持多种解码策略和参数配置。
- 张量并行: 支持多 GPU 环境下的并行计算。
- 量化支持: 支持 int8 和 int4 量化,减少内存占用。
- 确保输入提示格式正确,符合模型期望。
- 在长时间生成任务中,监控内存使用情况。
- 根据硬件配置调整批量大小和序列长度。
- 文本生成
- 对话系统
- 代码生成
- 机器翻译
- 摘要生成
本工具通过封装高效的解码策略和优化技术,为自然语言处理任务提供了强大的生成能力,特别适合需要高性能和低延迟的应用场景。
本模块提供了文本分词和编码的接口,支持多种分词器实现,包括 SentencePiece 和 Tiktoken。它通过统一的接口 TokenizerInterface
封装了不同的分词器,方便在不同模型和任务中使用。
抽象基类,定义了分词器的基本接口。
class TokenizerInterface:
def __init__(self, model_path):
self.model_path = model_path
def encode(self, text):
raise NotImplementedError("This method should be overridden by subclasses.")
def decode(self, tokens):
raise NotImplementedError("This method should be overridden by subclasses.")
def bos_id(self):
raise NotImplementedError("This method should be overridden by subclasses.")
def eos_id(self):
raise NotImplementedError("This method should be overridden by subclasses.")
功能: 定义分词器的基本接口。
主要方法:
-
encode
: 将文本编码为 token IDs。 -
decode
: 将 token IDs 解码为文本。 -
bos_id
: 获取开头标记的 ID。 -
eos_id
: 获取结尾标记的 ID。
SentencePiece 分词器的封装实现。
class SentencePieceWrapper(TokenizerInterface):
def __init__(self, model_path):
super().__init__(model_path)
self.processor = spm.SentencePieceProcessor(str(model_path))
def encode(self, text):
return self.processor.EncodeAsIds(text)
def decode(self, tokens):
return self.processor.DecodeIds(tokens)
def bos_id(self):
return self.processor.bos_id()
def eos_id(self):
return self.processor.eos_id()
功能: 使用 SentencePiece 库进行文本分词和编码。
主要方法:
-
encode
: 将文本编码为 token IDs。 -
decode
: 将 token IDs 解码为文本。 -
bos_id
: 获取开头标记的 ID。 -
eos_id
: 获取结尾标记的 ID。
Tiktoken 分词器的封装实现。
class TiktokenWrapper(TokenizerInterface):
special_tokens: Dict[str, int]
num_reserved_special_tokens = 256
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
def __init__(self, model_path):
super().__init__(model_path)
assert os.path.isfile(model_path), str(model_path)
mergeable_ranks = load_tiktoken_bpe(str(model_path))
num_base_tokens = len(mergeable_ranks)
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
] + [f"<|reserved_special_token_{i}|>" for i in range(5, self.num_reserved_special_tokens - 5)]
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
self.model = tiktoken.Encoding(
name=Path(model_path).name,
pat_str=self.pat_str,
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens,
)
# BOS / EOS token IDs
self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
self._eos_id: int = self.special_tokens["<|end_of_text|>"]
def encode(self, text):
return self.model.encode(text)
def decode(self, tokens):
return self.model.decode(tokens)
def bos_id(self):
return self._bos_id
def eos_id(self):
return self._eos_id
功能: 使用 Tiktoken 库进行文本分词和编码。
主要方法:
-
encode
: 将文本编码为 token IDs。 -
decode
: 将 token IDs 解码为文本。 -
bos_id
: 获取开头标记的 ID。 -
eos_id
: 获取结尾标记的 ID。
分词器工厂函数,根据模型名称返回适当的分词器实例。
def get_tokenizer(tokenizer_model_path, model_name):
"""
Factory function to get the appropriate tokenizer based on the model name.
Args:
- tokenizer_model_path (str): The file path to the tokenizer model.
- model_name (str): The name of the model, used to determine the tokenizer type.
Returns:
- TokenizerInterface: An instance of a tokenizer.
"""
if "llama-3" in str(model_name).lower():
return TiktokenWrapper(tokenizer_model_path)
else:
return SentencePieceWrapper(tokenizer_model_path)
功能: 根据模型名称返回相应的分词器实例。
参数:
-
tokenizer_model_path
: 分词器模型路径。 -
model_name
: 模型名称,用于确定分词器类型。
返回值:
-
TokenizerInterface
: 分词器实例。
# 获取分词器
tokenizer_model_path = "path/to/tokenizer.model"
model_name = "llama-3"
tokenizer = get_tokenizer(tokenizer_model_path, model_name)
# 编码文本
text = "Hello, world!"
tokens = tokenizer.encode(text)
print(tokens)
# 解码 token IDs
decoded_text = tokenizer.decode(tokens)
print(decoded_text)
# 获取开头和结尾标记 ID
bos_id = tokenizer.bos_id()
eos_id = tokenizer.eos_id()
print(f"BOS ID: {bos_id}, EOS ID: {eos_id}")
-
统一接口: 通过
TokenizerInterface
提供统一的分词器接口。 - 多种实现: 支持 SentencePiece 和 Tiktoken 两种分词器实现。
- 灵活配置: 根据模型名称自动选择合适的分词器。
- 确保分词器模型文件存在且路径正确。
- 根据模型需求选择合适的分词器类型。
- 文本生成
- 机器翻译
- 摘要生成
- 对话系统
本模块通过封装不同的分词器实现,为自然语言处理任务提供了灵活且高效的文本处理能力。
import math from dataclasses import dataclass from typing import Optional
import torch import torch.nn as nn from torch import Tensor from torch.nn import functional as F from torch.nn.attention.flex_attention import ( _mask_mod_signature, BlockMask, flex_attention, )
def find_multiple(n: int, k: int) -> int: if n % k == 0: return n return n + k - (n % k)
def get_mask_mod(mask_mod: _mask_mod_signature, offset: int): def _mask_mod(b, h, q, kv): return mask_mod(b, h, q + offset, kv)
return _mask_mod
@dataclass class ModelArgs: block_size: int = 2048 vocab_size: int = 32000 n_layer: int = 32 n_head: int = 32 dim: int = 4096 intermediate_size: int = None n_local_heads: int = -1 head_dim: int = 64 rope_base: float = 10000 norm_eps: float = 1e-5 rope_scaling: Optional[dict] = None
def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head
@classmethod
def from_name(cls, name: str):
if name in transformer_configs:
return cls(**transformer_configs[name])
# fuzzy search
config = [config for config in transformer_configs if config.lower() in str(name).lower()]
# We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
# take longer name (as it have more symbols matched)
if len(config) > 1:
config.sort(key=len, reverse=True)
assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
return cls(**transformer_configs[config[0]])
transformer_configs = { "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000), "7B": dict(n_layer=32, n_head=32, dim=4096), "13B": dict(n_layer=40, n_head=40, dim=5120), "30B": dict(n_layer=60, n_head=52, dim=6656), "34B": dict( n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000 ), # CodeLlama-34B-Python-hf "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), "stories15M": dict(n_layer=6, n_head=6, dim=288), "stories110M": dict(n_layer=12, n_head=12, dim=768), "llama-3-8b": dict( block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, ), "llama-3-70b": dict( block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000, ), "llama-3.1-8b": dict( block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192), ), "llama-3.1-70b": dict( block_size=131072, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000, rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192), ), "llama-3.1-405b": dict( block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000, rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192), ), }
class KVCache(nn.Module): def init(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): super().init() cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
assert input_pos.shape[0] == k_val.shape[2]
k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
return k_out, v_out
class Transformer(nn.Module): def init(self, config: ModelArgs) -> None: super().init() self.config = config
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
self.freqs_cis: Optional[Tensor] = None
self.mask_cache: Optional[Tensor] = None
self.max_batch_size = -1
self.max_seq_length = -1
self.get_mask_mod = get_mask_mod
def setup_caches(self, max_batch_size, max_seq_length):
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
return
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
dtype = self.output.weight.dtype
# For quantized layers, dtype is encoded in scales
if hasattr(self.output, "scales"):
dtype = self.output.scales.dtype
elif hasattr(self.output, "scales_and_zeros"):
dtype = self.output.scales_and_zeros.dtype
for b in self.layers:
b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
self.freqs_cis = precompute_freqs_cis(
self.config.block_size,
self.config.dim // self.config.n_head,
self.config.rope_base,
dtype,
self.config.rope_scaling,
)
def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
assert self.freqs_cis is not None, "Caches must be initialized first"
mask.mask_mod = self.get_mask_mod(mask.mask_mod, input_pos[0])
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)
for i, layer in enumerate(self.layers):
x = layer(x, input_pos, freqs_cis, mask)
x = self.norm(x)
logits = self.output(x)
return logits
@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))
class TransformerBlock(nn.Module): def init(self, config: ModelArgs) -> None: super().init() self.attention = Attention(config) self.feed_forward = FeedForward(config) self.ffn_norm = RMSNorm(config.dim, config.norm_eps) self.attention_norm = RMSNorm(config.dim, config.norm_eps)
def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: BlockMask) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Attention(nn.Module): def init(self, config: ModelArgs): super().init() assert config.dim % config.n_head == 0
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
# key, query, value projections for all heads, but in a batch
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
self.wo = nn.Linear(config.dim, config.dim, bias=False)
self.kv_cache = None
self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict, prefix, *args):
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:
bsz, seqlen, _ = x.shape
kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
if self.kv_cache is not None:
k, v = self.kv_cache.update(input_pos, k, v)
y = flex_attention(q, k, v, block_mask=mask, enable_gqa=(self.n_head != self.n_local_heads))
y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
y = self.wo(y)
return y
class FeedForward(nn.Module): def init(self, config: ModelArgs) -> None: super().init() self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
def forward(self, x: Tensor) -> Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class RMSNorm(nn.Module): def init(self, dim: int, eps: float = 1e-5): super().init() self.eps = eps self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
def forward(self, x: Tensor) -> Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight
def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None): factor = rope_scaling["factor"] low_freq_factor = rope_scaling["low_freq_factor"] high_freq_factor = rope_scaling["high_freq_factor"] old_context_len = rope_scaling["original_max_position_embeddings"]
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
def precompute_freqs_cis( seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16, rope_scaling: Optional[dict] = None, ) -> Tensor: freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) if rope_scaling is not None: freqs = apply_rope_scaling(freqs, rope_scaling) t = torch.arange(seq_len, device=freqs.device) freqs = torch.outer(t, freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) return cache.to(dtype=dtype)
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: xshaped = x.float().reshape(*x.shape[:-1], -1, 2) freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) x_out2 = torch.stack( [ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], ], -1, )
x_out2 = x_out2.flatten(3)
return x_out2.type_as(x)