LLM 大模型推理机制全解析:从 Transformer 到推理优化的完整技术栈

一、引言

大语言模型(LLM)的推理过程远不止”输入→输出”的黑盒。从 Transformer 架构的逐层计算,到自注意力机制的数学本质,再到 KV Cache、采样策略和新兴的推理优化技术,每一个环节都是提升模型性能和降低部署成本的关键。本文将系统梳理 LLM 推理的完整技术栈,深入每一层细节。

二、Transformer 架构逐层拆解

2.1 整体架构

LLM 的核心是 Decoder-only 的 Transformer 架构。它由多层相同的 Decoder Block 堆叠而成,每一层包含以下核心组件:

flowchart TD
    subgraph "单层 Decoder Block"
        A[输入 Token Embedding] --> B[位置编码]
        B --> C[Masked Self-Attention]
        C --> D[残差连接 + LayerNorm]
        D --> E[Feed-Forward Network]
        E --> F[残差连接 + LayerNorm]
        F --> G[下一层]
    end

2.2 Token Embedding 层

每个 token 被映射为一个 d_model 维度的稠密向量(通常是 4096、8192 等),形成词嵌入矩阵:

词嵌入维度: |V| × d_model
V = 词表大小(约 32k ~ 128k)
d_model = 隐藏层维度(4096 或更高)

关键计算: 在推理的最终层,使用词嵌入矩阵的转置对输出做 矩阵乘法(logits),得到每个 token 的分数:

# 简化代码
logits = output @ embedding_matrix.T  # [batch, seq_len, d_model] @ [d_model, vocab]
probs = softmax(logits[:, -1, :])     # 取最后一个位置的分布
next_token = sample(probs)

2.3 残差连接与 LayerNorm

每一层子层的输出都通过残差连接与 LayerNorm 组合:

output = LayerNorm(x + Sublayer(x))

其中 LayerNorm(层归一化)的计算为:

def layer_norm(x, gamma, beta, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    var  = x.var(dim=-1, keepdim=True)
    return gamma * (x - mean) / sqrt(var + eps) + beta

与 BatchNorm 不同,LayerNorm 在 特征维度(最后一个维度)上做归一化,不受序列长度影响,更适合 Transformer。

三、自注意力机制(Self-Attention)深度拆解

3.1 Scaled Dot-Product Attention

注意力机制的核心是一个 查询-键-值 的三元组计算:

def attention(Q, K, V, mask=None):
    # Q, K, V: [batch, num_heads, seq_len, d_k]
    scores = Q @ K.transpose(-2, -1)  # [batch, h, seq, seq]
    scores = scores / sqrt(d_k)       # 缩放

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    attn_weights = softmax(scores, dim=-1)
    output = attn_weights @ V         # [batch, h, seq, d_k]
    return output

为什么需要缩放 sqrt(d_k) 当 d_k 较大时,点积的结果方差较大,导致 softmax 的梯度进入饱和区域(梯度极小),缩放后方差归一化,梯度更稳定。

计算 参数量 计算复杂度 内存
Q/K/V 线性映射 3 × d_model × d_model
注意力分数 O(n² × d_k) O(n²)
加权求和 O(n² × d_v) O(n²)

核心瓶颈: 注意力分数的计算复杂度和内存占用都是序列长度的 二次方 O(n²),这直接限制了大模型的上下文长度。

3.2 Multi-Head Attention

Multi-Head 将注意力计算分解为多个 “头”,每个头学习不同的注意力模式:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        Q = self.W_q(x).reshape(batch, seq, h, d_k).transpose(1,2)
        K = self.W_k(x).reshape(batch, seq, h, d_k).transpose(1,2)
        V = self.W_v(x).reshape(batch, seq, h, d_k).transpose(1,2)

        out = scaled_dot_attention(Q, K, V, mask)
        out = out.transpose(1,2).reshape(batch, seq, -1)
        return self.W_o(out)
flowchart LR
    subgraph "多头机制"
        X[输入向量] --> P[线性投影]
        P --> H1[1: 捕捉局部语义]
        P --> H2[2: 捕捉长距离依赖]
        P --> H3[3: 捕捉语法信息]
        P --> Hn[n: 捕捉其他模式]
        H1 --> Concat[拼接]
        H2 --> Concat
        H3 --> Concat
        Hn --> Concat
        Concat --> O[输出投影]
    end

不同头的分工: 研究表明,不同注意力头确实学习了不同的模式——有的关注相邻词(局部语法),有的关注句子尾部(全局语义),还有的关注特殊位置(如句首标记)。

3.3 Masked Self-Attention(因果掩码)

在 decoder-only 架构中,每个 token 只能关注它之前的 token(包括自身),不能”偷看”未来的 token:

def create_causal_mask(seq_len):
    # 上三角矩阵,位置 i 只能看到位置 0..i
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask  # [[1,0,0], [1,1,0], [1,1,1]]

四、KV Cache 推理加速

4.1 为什么需要 KV Cache

在自回归生成中,模型逐 token 生成。如果每次生成新 token 时都重新计算所有历史位置的 K 和 V,会产生大量冗余计算:

# ❌ 无 KV Cache:每次重新计算整段历史
def generate_without_cache(model, prompt, max_len):
    for i in range(max_len):
        # 每次都从 0 到 i 计算所有位置的 K/V
        logits = model(prompt[:, :i])
        next_token = sample(logits[:, -1, :])
        prompt = torch.cat([prompt, next_token], dim=-1)

# ✅ 有 KV Cache:只计算新 token 的 K/V
def generate_with_cache(model, prompt, max_len):
    # 第一次:计算所有 prompt token 的 K/V
    logits, kv_cache = model.forward_with_cache(prompt)

    for i in range(max_len):
        next_token = sample(logits[:, -1, :])
        # 只传入新 token,利用缓存中的历史 K/V
        logits, kv_cache = model.forward_with_cache(
            next_token, kv_cache)

4.2 KV Cache 的内存开销

KV Cache 的内存消耗随着序列长度线性增长(但”线性”在长序列下也相当可观):

KV Cache 内存 = 2 × num_layers × num_heads × d_k × seq_len × precision
              = 2 × L × h × d_k × n × bytes_per_elem

示例:LLaMA-7B(L=32, h=32, d_k=128, FP16=2bytes)
seq=2048: 2×32×32×128×2048×2 ≈ 1.07 GB
seq=8192: 2×32×32×128×8192×2 ≈ 4.29 GB
seq=32768: ~17.2 GB
模型 参数量 seq_len=2048 KV Cache seq_len=32768 KV Cache
LLaMA-7B 7B 1 GB 16 GB
LLaMA-70B 70B 8 GB 128 GB
Qwen-72B 72B 7.5 GB 120 GB

优化方向: Multi-Query Attention (MQA) 和 Grouped-Query Attention (GQA):多个头共享 K/V 以节省缓存。

4.3 GQA 对比 MHA

flowchart TD
    subgraph "MHA: 标准多头"
        MHA_Q[Q: 32 heads] --> MHA_K[K: 32 heads]
        MHA_Q --> MHA_V[V: 32 heads]
    end
    subgraph "GQA: 分组查询"
        GQA_Q[Q: 32 heads] --> GQA_K[K: 8 heads]
        GQA_Q --> GQA_V[V: 8 heads]
    end
    subgraph "MQA: 单键值"
        MQA_Q[Q: 32 heads] --> MQA_K[K: 1 head]
        MQA_Q --> MQA_V[V: 1 head]
    end

五、采样策略(Decoding Strategies)

5.1 Temperature 缩放

温度参数控制 softmax 输出的概率分布的”尖锐程度”:

def temperature_scale(logits, temperature):
    # temperature > 1: 分布更均匀(更随机)
    # temperature = 1: 原始分布
    # temperature < 1: 分布更尖锐(更确定)
    # temperature → 0: 退化为argmax
    return logits / temperature

# 示例
logits = torch.tensor([1.0, 2.0, 3.0, 0.5])

print(softmax(logits / 0.5))   # 更尖锐: [0.04, 0.29, 0.65, 0.02]
print(softmax(logits / 1.0))   # 原始:   [0.11, 0.24, 0.52, 0.13]
print(softmax(logits / 2.0))   # 更均匀: [0.19, 0.26, 0.35, 0.20]

5.2 Top-K 采样

只从概率最高的 K 个 token 中采样,过滤低概率 token:

def top_k_sampling(logits, k=50):
    # 只保留 top-k 的 token
    values, indices = torch.topk(logits, k)

    # 将其他 token 的概率设为 -inf
    filtered = torch.full_like(logits, float('-inf'))
    filtered[indices] = logits[indices]

    probs = softmax(filtered, dim=-1)
    return torch.multinomial(probs, 1)

5.3 Top-P(Nucleus)采样

动态选择概率累积和达到 P 的最小 token 集合:

def top_p_sampling(logits, p=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    sorted_probs = softmax(sorted_logits, dim=-1)

    # 计算累积概率
    cumsum = torch.cumsum(sorted_probs, dim=-1)

    # 去掉累积超过 p 的 token
    sorted_probs[cumsum > p] = 0.0

    # 重归一化
    return torch.multinomial(sorted_probs / sorted_probs.sum(), 1)

5.4 策略对比

策略 多样性 可靠性 使用场景
Greedy(argmax) 极低 最高 事实性问答、翻译
Temperature (0.7) 一般对话
Top-K (50) 中高 创意写作
Top-P (0.9) 多样性要求高
Temperature + Top-P + Top-K 可调 可调 生产环境主流方案

六、推理优化技术

6.1 Flash Attention

Flash Attention 通过 tiling(分块)重计算 将注意力计算的显存占用从 O(n²) 降低到 O(n):

flowchart LR
    subgraph "标准 Attention"
        S1[Q,K] --> S2[Scores O()] --> S3[Softmax O()] --> S4[Output O()]
    end
    subgraph "Flash Attention"
        F1[分块 Q/K] --> F2[片上 SRAM 计算] --> F3[增量更新 Output]
        F2 -->|在线 softmax 重计算| F2
    end

核心技术:
1. 分块 Tiling: 将 Q/K/V 分成小块,每次只在 SRAM 中计算一小块
2. 重计算: 前向传播时不保留中间的注意力矩阵,反向传播时重新计算
3. 效率: 在 A100 上训练速度提升 2-4 倍,长序列效果更显著

6.2 量化(Quantization)

将模型权重从 FP16(16位浮点)压缩到更低精度:

精度 位宽 内存 7B 模型 推理速度 精度损失
FP16 16 14 GB 基准 0
INT8 8 7 GB 2x 非常小
INT4 4 3.5 GB 3-4x 轻微
NF4 4 3.5 GB 3-4x 几乎无
# 伪代码:量化过程
def quantize_to_int8(weights):
    # 计算缩放因子
    scale = weights.abs().max() / 127.0
    # 量化
    q_weights = (weights / scale).round().to(torch.int8)
    return q_weights, scale

# 反量化
def dequantize(q_weights, scale):
    return q_weights * scale

6.3 推测解码(Speculative Decoding)

使用一个小模型(草稿模型)快速生成多个候选 token,再用大模型一次性验证:

sequenceDiagram
    participant D as 草稿模型 ()
    participant T as 目标模型 ()

    D->>D: 快速生成 K 个候选 token
    D->>T: 提交候选序列
    T->>T: 并行验证所有候选
    T->>T: 发现第 3  token 概率低
    T->>D: 接受前 2 个,拒绝后续
    Note over T: 一次推理 = 接受多个 token

收益: 在不改变生成质量的前提下,推理延迟降低 2-3 倍。关键条件:草稿模型和目标模型的分布足够接近。

6.4 PagedAttention

vLLM 提出的内存管理方案,将 KV Cache 分页管理:

技术 内存占用 碎片 共享效率
连续分配 基线 60-80% 不支持
PagedAttention 减少 50%+ <10% 支持 Copy-on-Write

适用场景: 高并发 Serving 和长序列推理,vLLM 将此技术落地为生产级推理引擎。

七、总结

组件 角色 瓶颈 优化方向
Transformer Block 特征提取 逐层串行 模型蒸馏、层剪枝
Self-Attention 上下文建模 O(n²) 计算 Flash Attention、稀疏注意力
KV Cache 历史复用 线性增长 GQA、PagedAttention
采样策略 输出生成 逐 token 推测解码、并行解码
模型权重 存储参数 内存占用 量化(INT4/8)

LLM 推理优化正在从”能不能跑”走向”跑得快、跑得省”。理解这些底层机制,对于模型选型、部署成本估算和推理架构设计都至关重要。未来,随着长上下文(100K+ token)和多模态输入成为标配,这些优化技术还将继续演进。

© 版权声明
THE END
喜欢就支持一下吧
点赞10 分享
评论 抢沙发

请登录后发表评论

    暂无评论内容