src(来自gpt‐fast的repo) - Nanji-Huaji/hierarchical-speculative-decoding GitHub Wiki

以下内容由AI生成

generate.py

Transformer 文本生成工具文档

概述

本项目提供了一个基于 Transformer 架构的高效文本生成工具,支持多种解码策略,包括自回归采样、推测解码和张量并行。它专为高性能生成任务设计,适用于各种自然语言处理应用场景。

类和函数文档

1. device_sync

def device_sync(device):
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif ("cpu" in device) or ("mps" in device):
        pass
    else:
        print(f"device={device} is not yet suppported")

功能: 根据设备类型进行同步操作。

参数:

  • device (str): 设备类型(如 "cuda""cpu")。

2. Transformer

class Transformer:
    # 模型定义和方法
    ...

功能: Transformer 模型类,负责文本生成的核心计算。

主要方法:

  • setup_caches: 设置缓存。
  • forward: 模型前向传播。

3. get_tokenizer

from tokenizer import get_tokenizer

功能: 获取分词器。

参数:

  • tokenizer_path: 分词器路径。
  • checkpoint_path: 模型检查点路径。

返回值:

  • 分词器对象。

4. multinomial_sample_one_no_sync

def multinomial_sample_one_no_sync(probs_sort):
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

功能: 进行多项式采样,无需 CUDA 同步。

参数:

  • probs_sort: 概率张量。

返回值:

  • 采样的 token 索引。

5. logits_to_probs

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

功能: 将 logits 转换为概率分布,支持温度控制和 Top-K 过滤。

参数:

  • logits: 输入的 logits 张量。
  • temperature: 温度参数。
  • top_k: Top-K 采样参数。

返回值:

  • 转换后的概率张量。

6. sample

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

功能: 从 logits 中采样下一个 token。

参数:

  • logits: 输入的 logits 张量。
  • temperature: 温度参数。
  • top_k: Top-K 采样参数。

返回值:

  • 采样的 token 索引和概率。

7. roundup

def roundup(val, multiplier):
    return ((val - 1) // multiplier + 1) * multiplier

功能: 将值向上取整到最近的倍数。

参数:

  • val: 输入值。
  • multiplier: 倍数。

返回值:

  • 向上取整后的值。

8. causal_mask

def causal_mask(b, h, q, kv):
    return q >= kv

功能: 创建因果掩码。

参数:

  • b: 批量大小。
  • h: 注意力头数。
  • q: 查询位置。
  • kv: 键值位置。

返回值:

  • 因果掩码。

9. prefill

def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device=x.device)
    logits = model(mask, x, input_pos)
    return sample(logits, **sampling_kwargs)[0]

功能: 预填充模型,生成初始 token。

参数:

  • model: Transformer 模型。
  • x: 输入张量。
  • input_pos: 输入位置。
  • sampling_kwargs: 采样参数。

返回值:

  • 采样的 token。

10. decode_one_token

def decode_one_token(
    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, block_mask: BlockMask, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    block_index = input_pos // block_mask.BLOCK_SIZE[0]
    mask = block_mask[:, :, block_index]
    mask.mask_mod = block_mask.mask_mod
    mask.seq_lengths = (1, model.max_seq_length)
    logits = model(mask, x, input_pos)
    return sample(logits, **sampling_kwargs)

功能: 解码一个 token。

参数:

  • model: Transformer 模型。
  • x: 输入张量。
  • input_pos: 输入位置。
  • block_mask: 块掩码。
  • sampling_kwargs: 采样参数。

返回值:

  • 采样的 token 和概率。

11. decode_n_tokens

def decode_n_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    callback=lambda _: _,
    **sampling_kwargs,
):
    block_mask = create_block_mask(
        causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device=cur_token.device
    )
    new_tokens, new_probs = [], []
    for i in range(num_new_tokens):
        next_token, next_prob = decode_one_token(model, cur_token, input_pos, block_mask, **sampling_kwargs)
        input_pos += 1
        new_tokens.append(next_token.clone())
        callback(new_tokens[-1])
        new_probs.append(next_prob.clone())
        cur_token = next_token.clone()

    return new_tokens, new_probs

功能: 解码多个 token。

参数:

  • model: Transformer 模型。
  • cur_token: 当前 token。
  • input_pos: 输入位置。
  • num_new_tokens: 需要生成的 token 数量。
  • callback: 回调函数。
  • sampling_kwargs: 采样参数。

返回值:

  • 生成的 token 列表和概率列表。

12. model_forward

def model_forward(model, x, input_pos):
    return model(x, input_pos)

功能: 执行模型的前向传播。

参数:

  • model: Transformer 模型。
  • x: 输入张量。
  • input_pos: 输入位置。

返回值:

  • 模型的输出。

13. speculative_decode

def speculative_decode(
    model: Transformer,
    draft_model: Transformer,
    cur_token: torch.Tensor,
    input_pos: int,
    speculate_k: int,
    **sampling_kwargs,
) -> torch.Tensor:
    # draft model inference sequentially
    device = cur_token.device
    orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
    draft_tokens, draft_probs = decode_n_tokens(
        draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs
    )

    draft_tokens = torch.cat(draft_tokens)
    # parallel inference on target model using draft tokens
    target_logits = model_forward(
        model,
        torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
        torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device),
    )
    target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
    draft_probs = torch.stack(draft_probs)
    # q: target prob, p: draft prob
    # q >= p: always accept draft token
    # q < p: q/p prob to accept draft token
    p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
    q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
    accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k] / p)
    rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()

    if rejected_locations.shape[0] == 0:  # All draft tokens have been accepted
        accept_length = speculate_k + 1
        last_token = multinomial_sample_one_no_sync(target_probs[-1])
        # fill last token into draft model
        model_forward(
            draft_model,
            draft_tokens[-1].view(1, -1),
            orig_input_pos + speculate_k,
        )
        return torch.cat([draft_tokens, last_token])
    else:
        accept_length = rejected_locations[0].item()
        p = draft_probs[accept_length]
        q = target_probs[accept_length]
        new = q - p
        new = torch.where(new > 0, new, 0.0)
        new = new / new.sum()
        next_token = multinomial_sample_one_no_sync(new)
        return torch.cat([draft_tokens[:accept_length], next_token])

功能: 执行推测解码。

参数:

  • model: Transformer 模型。
  • draft_model: 草稿模型。
  • cur_token: 当前 token。
  • input_pos: 输入位置。
  • speculate_k: 推测深度。
  • sampling_kwargs: 采样参数。

返回值:

  • 生成的 token。

14. generate

@torch.no_grad()
def generate(
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    batch_size: int,
    *,
    interactive: bool,
    draft_model: Transformer,
    speculate_k: Optional[int] = 8,
    callback=lambda x: x,
    **sampling_kwargs,
) -> torch.Tensor:
    # 主生成函数的实现
    ...

功能: 主生成函数,负责协调模型加载、文本编码和生成过程。

参数:

  • model: Transformer 模型。
  • prompt: 输入提示。
  • max_new_tokens: 最大生成的 token 数量。
  • batch_size: 批量大小。
  • interactive: 是否启用交互模式。
  • draft_model: 草稿模型。
  • speculate_k: 推测深度。
  • callback: 回调函数。
  • sampling_kwargs: 采样参数。

返回值:

  • 生成的序列和统计信息。

15. encode_tokens

def encode_tokens(tokenizer, string, bos=True, device=default_device):
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)

功能: 将字符串编码为 token IDs。

参数:

  • tokenizer: 分词器。
  • string: 输入字符串。
  • bos: 是否添加开头标记。
  • device: 目标设备。

返回值:

  • 编码后的 token IDs。

16. _load_model

def _load_model(checkpoint_path, device, precision, use_tp):
    # 加载模型的实现
    ...

功能: 从检查点加载模型。

参数:

  • checkpoint_path: 模型检查点路径。
  • device: 目标设备。
  • precision: 模型精度。
  • use_tp: 是否使用张量并行。

返回值:

  • 加载的模型。

17. _get_model_size

def _get_model_size(model):
    model_size = 0
    params = 0
    for name, child in model.named_children():
        if not isinstance(child, torch.nn.Embedding):
            model_size += sum(
                [p.numel() * p.dtype.itemsize for p in itertools.chain(child.parameters(), child.buffers())]
            )
            params += sum([p.numel() for p in itertools.chain(child.parameters(), child.buffers())])
    return model_size, params

功能: 获取模型大小和参数数量。

参数:

  • model: 模型。

返回值:

  • 模型大小和参数数量。

18. main

def main(
    prompt: Union[int, str] = "Hello, my name is",
    interactive: bool = False,
    num_samples: int = 5,
    max_new_tokens: int = 100,
    batch_size: int = 1,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
    compile: bool = True,
    compile_prefill: bool = False,
    profile: Optional[Path] = None,
    draft_checkpoint_path: Optional[Path] = None,
    speculate_k: int = 5,
    device=default_device,
) -> None:
    # 主函数的实现
    ...

功能: 主函数,负责初始化和运行文本生成过程。

参数:

  • prompt: 输入提示。
  • interactive: 是否启用交互模式。
  • num_samples: 生成的样本数量。
  • max_new_tokens: 最大生成的 token 数量。
  • batch_size: 批量大小。
  • top_k: Top-K 采样参数。
  • temperature: 温度参数。
  • checkpoint_path: 模型检查点路径。
  • compile: 是否编译模型。
  • compile_prefill: 是否编译预填充阶段。
  • profile: 性能分析文件路径。
  • draft_checkpoint_path: 草稿模型检查点路径。
  • speculate_k: 推测深度。
  • device: 目标设备。

逻辑:

  • 初始化分布式训练环境。
  • 加载模型和分词器。
  • 编码输入提示。
  • 根据配置进行文本生成。
  • 输出生成结果和性能指标。

使用示例

基本用法

# 加载模型和分词器
checkpoint_path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth")
model = _load_model(checkpoint_path, device="cuda", precision=torch.bfloat16, use_tp=False)
tokenizer = get_tokenizer(checkpoint_path.parent / "tokenizer.model", checkpoint_path)

# 编码输入提示
prompt = "Hello, my name is"
encoded = encode_tokens(tokenizer, prompt, bos=True, device="cuda")

# 生成文本
generated_text, metrics = generate(
    model,
    encoded,
    max_new_tokens=100,
    batch_size=1,
    temperature=0.8,
    top_k=200,
)

print(tokenizer.decode(generated_text[0].tolist()))

推测解码示例

# 加载草稿模型和目标模型
draft_model = _load_model(draft_checkpoint_path, device="cuda", precision=torch.bfloat16, use_tp=False)
target_model = _load_model(checkpoint_path, device="cuda", precision=torch.bfloat16, use_tp=False)

# 推测解码
speculative_decode(target_model, draft_model, cur_token, input_pos, speculate_k=5)

关键特性

  1. 高效解码: 通过缓存机制和编译优化,提高生成效率。
  2. 灵活配置: 支持多种解码策略和参数配置。
  3. 张量并行: 支持多 GPU 环境下的并行计算。
  4. 量化支持: 支持 int8 和 int4 量化,减少内存占用。

注意事项

  1. 确保输入提示格式正确,符合模型期望。
  2. 在长时间生成任务中,监控内存使用情况。
  3. 根据硬件配置调整批量大小和序列长度。

应用场景

  • 文本生成
  • 对话系统
  • 代码生成
  • 机器翻译
  • 摘要生成

本工具通过封装高效的解码策略和优化技术,为自然语言处理任务提供了强大的生成能力,特别适合需要高性能和低延迟的应用场景。

tokenizer.py

Tokenizer 模块文档

概述

本模块提供了文本分词和编码的接口,支持多种分词器实现,包括 SentencePiece 和 Tiktoken。它通过统一的接口 TokenizerInterface 封装了不同的分词器,方便在不同模型和任务中使用。

类和方法文档

1. TokenizerInterface

抽象基类,定义了分词器的基本接口。

class TokenizerInterface:
    def __init__(self, model_path):
        self.model_path = model_path

    def encode(self, text):
        raise NotImplementedError("This method should be overridden by subclasses.")

    def decode(self, tokens):
        raise NotImplementedError("This method should be overridden by subclasses.")

    def bos_id(self):
        raise NotImplementedError("This method should be overridden by subclasses.")

    def eos_id(self):
        raise NotImplementedError("This method should be overridden by subclasses.")

功能: 定义分词器的基本接口。

主要方法:

  • encode: 将文本编码为 token IDs。
  • decode: 将 token IDs 解码为文本。
  • bos_id: 获取开头标记的 ID。
  • eos_id: 获取结尾标记的 ID。

2. SentencePieceWrapper

SentencePiece 分词器的封装实现。

class SentencePieceWrapper(TokenizerInterface):
    def __init__(self, model_path):
        super().__init__(model_path)
        self.processor = spm.SentencePieceProcessor(str(model_path))

    def encode(self, text):
        return self.processor.EncodeAsIds(text)

    def decode(self, tokens):
        return self.processor.DecodeIds(tokens)

    def bos_id(self):
        return self.processor.bos_id()

    def eos_id(self):
        return self.processor.eos_id()

功能: 使用 SentencePiece 库进行文本分词和编码。

主要方法:

  • encode: 将文本编码为 token IDs。
  • decode: 将 token IDs 解码为文本。
  • bos_id: 获取开头标记的 ID。
  • eos_id: 获取结尾标记的 ID。

3. TiktokenWrapper

Tiktoken 分词器的封装实现。

class TiktokenWrapper(TokenizerInterface):
    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path):
        super().__init__(model_path)
        assert os.path.isfile(model_path), str(model_path)
        mergeable_ranks = load_tiktoken_bpe(str(model_path))
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, self.num_reserved_special_tokens - 5)]
        self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        # BOS / EOS token IDs
        self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self._eos_id: int = self.special_tokens["<|end_of_text|>"]

    def encode(self, text):
        return self.model.encode(text)

    def decode(self, tokens):
        return self.model.decode(tokens)

    def bos_id(self):
        return self._bos_id

    def eos_id(self):
        return self._eos_id

功能: 使用 Tiktoken 库进行文本分词和编码。

主要方法:

  • encode: 将文本编码为 token IDs。
  • decode: 将 token IDs 解码为文本。
  • bos_id: 获取开头标记的 ID。
  • eos_id: 获取结尾标记的 ID。

4. get_tokenizer 函数

分词器工厂函数,根据模型名称返回适当的分词器实例。

def get_tokenizer(tokenizer_model_path, model_name):
    """
    Factory function to get the appropriate tokenizer based on the model name.

    Args:
    - tokenizer_model_path (str): The file path to the tokenizer model.
    - model_name (str): The name of the model, used to determine the tokenizer type.

    Returns:
    - TokenizerInterface: An instance of a tokenizer.
    """

    if "llama-3" in str(model_name).lower():
        return TiktokenWrapper(tokenizer_model_path)
    else:
        return SentencePieceWrapper(tokenizer_model_path)

功能: 根据模型名称返回相应的分词器实例。

参数:

  • tokenizer_model_path: 分词器模型路径。
  • model_name: 模型名称,用于确定分词器类型。

返回值:

  • TokenizerInterface: 分词器实例。

使用示例

基本用法

# 获取分词器
tokenizer_model_path = "path/to/tokenizer.model"
model_name = "llama-3"
tokenizer = get_tokenizer(tokenizer_model_path, model_name)

# 编码文本
text = "Hello, world!"
tokens = tokenizer.encode(text)
print(tokens)

# 解码 token IDs
decoded_text = tokenizer.decode(tokens)
print(decoded_text)

# 获取开头和结尾标记 ID
bos_id = tokenizer.bos_id()
eos_id = tokenizer.eos_id()
print(f"BOS ID: {bos_id}, EOS ID: {eos_id}")

关键特性

  1. 统一接口: 通过 TokenizerInterface 提供统一的分词器接口。
  2. 多种实现: 支持 SentencePiece 和 Tiktoken 两种分词器实现。
  3. 灵活配置: 根据模型名称自动选择合适的分词器。

注意事项

  1. 确保分词器模型文件存在且路径正确。
  2. 根据模型需求选择合适的分词器类型。

应用场景

  • 文本生成
  • 机器翻译
  • 摘要生成
  • 对话系统

本模块通过封装不同的分词器实现,为自然语言处理任务提供了灵活且高效的文本处理能力。

model.py

Copyright (c) Meta Platforms, Inc. and affiliates.

All rights reserved.

This source code is licensed under the license found in the

LICENSE file in the root directory of this source tree.

This file comes from gpt-fast project by Meta AI

import math from dataclasses import dataclass from typing import Optional

import torch import torch.nn as nn from torch import Tensor from torch.nn import functional as F from torch.nn.attention.flex_attention import ( _mask_mod_signature, BlockMask, flex_attention, )

def find_multiple(n: int, k: int) -> int: if n % k == 0: return n return n + k - (n % k)

def get_mask_mod(mask_mod: _mask_mod_signature, offset: int): def _mask_mod(b, h, q, kv): return mask_mod(b, h, q + offset, kv)

return _mask_mod

@dataclass class ModelArgs: block_size: int = 2048 vocab_size: int = 32000 n_layer: int = 32 n_head: int = 32 dim: int = 4096 intermediate_size: int = None n_local_heads: int = -1 head_dim: int = 64 rope_base: float = 10000 norm_eps: float = 1e-5 rope_scaling: Optional[dict] = None

def __post_init__(self):
    if self.n_local_heads == -1:
        self.n_local_heads = self.n_head
    if self.intermediate_size is None:
        hidden_dim = 4 * self.dim
        n_hidden = int(2 * hidden_dim / 3)
        self.intermediate_size = find_multiple(n_hidden, 256)
    self.head_dim = self.dim // self.n_head

@classmethod
def from_name(cls, name: str):
    if name in transformer_configs:
        return cls(**transformer_configs[name])
    # fuzzy search
    config = [config for config in transformer_configs if config.lower() in str(name).lower()]

    # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
    # take longer name (as it have more symbols matched)
    if len(config) > 1:
        config.sort(key=len, reverse=True)
        assert len(config[0]) != len(config[1]), name  # make sure only one 'best' match

    return cls(**transformer_configs[config[0]])

transformer_configs = { "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000), "7B": dict(n_layer=32, n_head=32, dim=4096), "13B": dict(n_layer=40, n_head=40, dim=5120), "30B": dict(n_layer=60, n_head=52, dim=6656), "34B": dict( n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000 ), # CodeLlama-34B-Python-hf "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), "stories15M": dict(n_layer=6, n_head=6, dim=288), "stories110M": dict(n_layer=12, n_head=12, dim=768), "llama-3-8b": dict( block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, ), "llama-3-70b": dict( block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000, ), "llama-3.1-8b": dict( block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192), ), "llama-3.1-70b": dict( block_size=131072, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000, rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192), ), "llama-3.1-405b": dict( block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000, rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192), ), }

class KVCache(nn.Module): def init(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): super().init() cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

def update(self, input_pos, k_val, v_val):
    # input_pos: [S], k_val: [B, H, S, D]
    assert input_pos.shape[0] == k_val.shape[2]

    k_out = self.k_cache
    v_out = self.v_cache
    k_out[:, :, input_pos] = k_val
    v_out[:, :, input_pos] = v_val

    return k_out, v_out

class Transformer(nn.Module): def init(self, config: ModelArgs) -> None: super().init() self.config = config

    self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
    self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
    self.norm = RMSNorm(config.dim, eps=config.norm_eps)
    self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

    self.freqs_cis: Optional[Tensor] = None
    self.mask_cache: Optional[Tensor] = None
    self.max_batch_size = -1
    self.max_seq_length = -1
    self.get_mask_mod = get_mask_mod

def setup_caches(self, max_batch_size, max_seq_length):
    if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
        return
    head_dim = self.config.dim // self.config.n_head
    max_seq_length = find_multiple(max_seq_length, 8)
    self.max_seq_length = max_seq_length
    self.max_batch_size = max_batch_size
    dtype = self.output.weight.dtype
    # For quantized layers, dtype is encoded in scales
    if hasattr(self.output, "scales"):
        dtype = self.output.scales.dtype
    elif hasattr(self.output, "scales_and_zeros"):
        dtype = self.output.scales_and_zeros.dtype
    for b in self.layers:
        b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

    self.freqs_cis = precompute_freqs_cis(
        self.config.block_size,
        self.config.dim // self.config.n_head,
        self.config.rope_base,
        dtype,
        self.config.rope_scaling,
    )

def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
    assert self.freqs_cis is not None, "Caches must be initialized first"
    mask.mask_mod = self.get_mask_mod(mask.mask_mod, input_pos[0])
    freqs_cis = self.freqs_cis[input_pos]
    x = self.tok_embeddings(idx)

    for i, layer in enumerate(self.layers):
        x = layer(x, input_pos, freqs_cis, mask)
    x = self.norm(x)
    logits = self.output(x)
    return logits

@classmethod
def from_name(cls, name: str):
    return cls(ModelArgs.from_name(name))

class TransformerBlock(nn.Module): def init(self, config: ModelArgs) -> None: super().init() self.attention = Attention(config) self.feed_forward = FeedForward(config) self.ffn_norm = RMSNorm(config.dim, config.norm_eps) self.attention_norm = RMSNorm(config.dim, config.norm_eps)

def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: BlockMask) -> Tensor:
    h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
    out = h + self.feed_forward(self.ffn_norm(h))
    return out

class Attention(nn.Module): def init(self, config: ModelArgs): super().init() assert config.dim % config.n_head == 0

    total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
    # key, query, value projections for all heads, but in a batch
    self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
    self.wo = nn.Linear(config.dim, config.dim, bias=False)
    self.kv_cache = None

    self.n_head = config.n_head
    self.head_dim = config.head_dim
    self.n_local_heads = config.n_local_heads
    self.dim = config.dim
    self._register_load_state_dict_pre_hook(self.load_hook)

def load_hook(self, state_dict, prefix, *args):
    if prefix + "wq.weight" in state_dict:
        wq = state_dict.pop(prefix + "wq.weight")
        wk = state_dict.pop(prefix + "wk.weight")
        wv = state_dict.pop(prefix + "wv.weight")
        state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:
    bsz, seqlen, _ = x.shape

    kv_size = self.n_local_heads * self.head_dim
    q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

    q = q.view(bsz, seqlen, self.n_head, self.head_dim)
    k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

    q = apply_rotary_emb(q, freqs_cis)
    k = apply_rotary_emb(k, freqs_cis)

    q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

    if self.kv_cache is not None:
        k, v = self.kv_cache.update(input_pos, k, v)

    y = flex_attention(q, k, v, block_mask=mask, enable_gqa=(self.n_head != self.n_local_heads))

    y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

    y = self.wo(y)
    return y

class FeedForward(nn.Module): def init(self, config: ModelArgs) -> None: super().init() self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

def forward(self, x: Tensor) -> Tensor:
    return self.w2(F.silu(self.w1(x)) * self.w3(x))

class RMSNorm(nn.Module): def init(self, dim: int, eps: float = 1e-5): super().init() self.eps = eps self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
    return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

def forward(self, x: Tensor) -> Tensor:
    output = self._norm(x.float()).type_as(x)
    return output * self.weight

def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None): factor = rope_scaling["factor"] low_freq_factor = rope_scaling["low_freq_factor"] high_freq_factor = rope_scaling["high_freq_factor"] old_context_len = rope_scaling["original_max_position_embeddings"]

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
    wavelen = 2 * math.pi / freq
    if wavelen < high_freq_wavelen:
        new_freqs.append(freq)
    elif wavelen > low_freq_wavelen:
        new_freqs.append(freq / factor)
    else:
        assert low_freq_wavelen != high_freq_wavelen
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

def precompute_freqs_cis( seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16, rope_scaling: Optional[dict] = None, ) -> Tensor: freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) if rope_scaling is not None: freqs = apply_rope_scaling(freqs, rope_scaling) t = torch.arange(seq_len, device=freqs.device) freqs = torch.outer(t, freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) return cache.to(dtype=dtype)

def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: xshaped = x.float().reshape(*x.shape[:-1], -1, 2) freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) x_out2 = torch.stack( [ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], ], -1, )

x_out2 = x_out2.flatten(3)
return x_out2.type_as(x)
⚠️ **GitHub.com Fallback** ⚠️