Transformer 架构深度解析
AI 导读
Transformer 架构深度解析 Maurice | 灵阙学院 2026-02-27 一句话理解 Transformer Transformer 的本质是一个"注意力驱动的序列到序列映射器":给定输入序列中的每个位置,它通过注意力机制动态地从所有其他位置收集信息,而不像 RNN 那样被迫按顺序逐步传递。这使得它天然支持并行计算,且能捕获任意距离的依赖关系。 整体架构...
Transformer 架构深度解析
Maurice | 灵阙学院 2026-02-27
一句话理解 Transformer
Transformer 的本质是一个"注意力驱动的序列到序列映射器":给定输入序列中的每个位置,它通过注意力机制动态地从所有其他位置收集信息,而不像 RNN 那样被迫按顺序逐步传递。这使得它天然支持并行计算,且能捕获任意距离的依赖关系。
整体架构
┌────────────────────────────────────────────────────────────┐
│ Transformer 架构 │
├────────────────────────────────────────────────────────────┤
│ 原始 (2017) 现代变体 │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ Encoder │ │ Decoder-Only│ GPT / LLaMA / │
│ │ + Decoder │ │ │ Qwen / Mistral │
│ └─────────────┘ └─────────────┘ │
│ ┌─────────────┐ │
│ │ Encoder-Only│ BERT / BGE │
│ └─────────────┘ │
│ ┌─────────────┐ │
│ │ Enc-Dec │ T5 / Whisper │
│ └─────────────┘ │
└────────────────────────────────────────────────────────────┘
| 架构 | 注意力类型 | 代表模型 | 典型任务 |
|---|---|---|---|
| Encoder-Only | 双向全注意力 | BERT, RoBERTa | 分类、NER、Embedding |
| Decoder-Only | 因果掩码注意力 | GPT, LLaMA, Qwen | 文本生成、对话、代码 |
| Encoder-Decoder | 编码器双向 + 解码器因果 | T5, BART, Whisper | 翻译、摘要、语音识别 |
当前大语言模型的主流是 Decoder-Only 架构,因为它在 scaling 时展现出最优的性能/效率比。
自注意力机制(Self-Attention)
直觉理解
想象你在读一句话:"小明把苹果给了小红,她很开心"。当你读到"她"时,你的大脑会自动回溯到"小红"来理解指代关系。自注意力做的就是这件事:让序列中的每个词都能"看到"并"关注"其他所有词。
Q/K/V 三元组
每个输入 token 被线性变换为三个向量:
- Query (Q):我在找什么信息?("她"的 Query 在问"谁是被指代的对象?")
- Key (K):我能提供什么信息?("小红"的 Key 在说"我是一个女性人名")
- Value (V):我的具体内容是什么?("小红"的 Value 携带完整语义)
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
除以 sqrt(d_k) 是为了防止点积过大导致 softmax 梯度消失。
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
return torch.matmul(attn_weights, V), attn_weights
多头注意力(Multi-Head Attention)
单一注意力只能学习一种"关注模式"。多头注意力将 Q/K/V 拆分到多个子空间,每个头独立计算注意力,最后拼接。不同头可以学到不同的语义关系(语法/语义/位置)。
class MultiHeadAttention(torch.nn.Module):
def __init__(self, d_model=512, n_heads=8):
super().__init__()
self.n_heads, self.d_k = n_heads, d_model // n_heads
self.W_Q = torch.nn.Linear(d_model, d_model)
self.W_K = torch.nn.Linear(d_model, d_model)
self.W_V = torch.nn.Linear(d_model, d_model)
self.W_O = torch.nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
B, L, D = x.shape
Q = self.W_Q(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_K(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_V(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)
out, _ = scaled_dot_product_attention(Q, K, V, mask)
return self.W_O(out.transpose(1, 2).contiguous().view(B, L, D))
位置编码(Positional Encoding)
注意力机制本身是位置无关的。必须显式注入位置信息,否则模型无法区分"猫追狗"和"狗追猫"。
| 方法 | 原理 | 外推能力 | 代表模型 | 年份 |
|---|---|---|---|---|
| 正弦余弦 | sin/cos 固定函数 | 弱 | 原始 Transformer | 2017 |
| 可学习绝对位置 | 可训练的位置向量 | 无 | GPT-2, BERT | 2019 |
| RoPE | 旋转矩阵编码相对位置 | 中(可插值扩展) | LLaMA, Qwen, Mistral | 2021 |
| ALiBi | 注意力分数加线性偏置 | 强(天然外推) | BLOOM, MPT | 2022 |
| YaRN | RoPE + 温度缩放 | 强(128K+) | LLaMA 3, Qwen 2.5 | 2023 |
RoPE 直觉
RoPE 不是给 token 加一个位置向量,而是对 Q 和 K 做旋转变换。位置 m 的向量被旋转 m * theta 角度。两个 token 的注意力分数只取决于它们的相对距离(旋转角度差),而不是绝对位置。
def precompute_rope(dim, max_seq_len, theta=10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len)
freqs = torch.outer(t, freqs)
return torch.cos(freqs), torch.sin(freqs)
def apply_rope(x, cos, sin):
x1, x2 = x[..., ::2], x[..., 1::2]
out = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
return out.flatten(-2)
Layer Normalization
- Post-Norm(原始论文):训练不稳定,需 warmup
- Pre-Norm(现代标准):训练更稳定,GPT/LLaMA 的默认选择
- RMSNorm(LLaMA 引入):去掉均值中心化,计算量减少 ~15%
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x):
return x / torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps) * self.weight
KV-Cache:推理加速的关键
自回归生成时,KV-Cache 将已计算的 K/V 缓存下来,新 token 只需计算自身的 Q 与缓存做注意力。计算量从 O(n^3d) 降到 O(n^2d)。
GQA (Grouped Query Attention):多个 Q 头共享同一组 K/V 头,KV-Cache 压缩 4-8 倍。
| 变体 | Q 头 | K/V 头 | Cache 大小 | 代表模型 |
|---|---|---|---|---|
| MHA | 32 | 32 | 基准 | GPT-3, LLaMA 1 |
| GQA | 32 | 8 | 1/4 | LLaMA 2/3, Qwen 2.5 |
| MQA | 32 | 1 | 1/32 | Falcon, PaLM |
Flash Attention
标准注意力的瓶颈不是计算而是内存带宽。Flash Attention 通过分块计算 + 在线 softmax,永远不将完整 n*n 矩阵写回 HBM。
| 版本 | 速度提升 | 内存节省 | 关键改进 |
|---|---|---|---|
| Flash Attention 1 | 2-4x | O(n) vs O(n^2) | 分块 + 在线 softmax |
| Flash Attention 2 | 额外 2x | 同上 | 并行化 + 减少非矩阵运算 |
| Flash Attention 3 | 额外 1.5x | 同上 | FP8 + Hopper 架构优化 |
# PyTorch 已内置,满足条件时自动调用 Flash Attention 内核
output = F.scaled_dot_product_attention(query, key, value, is_causal=True)
Mixture of Experts (MoE)
MoE 是扩展模型容量而不等比例增加计算量的方法。每层有 N 个 Expert(FFN),Router 为每个 token 选择 Top-K 个 Expert 执行,其余跳过。
| 模型 | 总参数 | 激活参数 | Expert 数 | Top-K |
|---|---|---|---|---|
| Mixtral 8x7B | 47B | ~13B | 8 | 2 |
| DeepSeek-V3 | 671B | ~37B | 256 | 8 |
| Qwen 2.5-MoE | 14B | ~2.7B | 60 | 4 |
核心挑战是负载均衡:辅助损失(auxiliary loss)和 Expert Choice routing 是主流解决方案。
现代 Decoder Block 数据流
Input -> RMSNorm -> GQA Attention (RoPE + KV-Cache + Causal Mask)
-> Residual Add -> RMSNorm -> SwiGLU FFN -> Residual Add -> Output
核心概念速查表
| 概念 | 一句话解释 | 为什么重要 |
|---|---|---|
| Self-Attention | 每个 token 动态关注所有其他 token | Transformer 的核心计算原语 |
| Multi-Head | 多个注意力子空间并行 | 捕获多种语义关系 |
| RoPE | 旋转位置编码 | 支持变长序列和位置外推 |
| KV-Cache | 缓存已计算的 K/V | 推理速度提升 100x+ |
| GQA | K/V 头共享 | KV-Cache 内存降 75% |
| Flash Attention | 分块计算避免 n*n 中间矩阵 | 训练/推理速度提升 2-4x |
| RMSNorm | 简化的 LayerNorm | 计算量减少,训练更稳定 |
| SwiGLU | 门控 FFN 激活函数 | 比 ReLU/GELU 效果更好 |
| MoE | 稀疏专家混合 | 总参数大但激活参数少 |
Maurice | maurice_wen@proton.me