3. Transformer 架构详解
导读:现代 LLM 都是 decoder-only Transformer 的变体。本章从自注意力机制的数学推导出发,逐步剖析 MHA / GQA / MQA、RoPE 位置编码、RMSNorm、SwiGLU、KV-Cache 等核心组件,给出每个设计的动机、推导与现代实现。
3.1 从 RNN 到 Transformer
在 2017 年之前,序列建模的主流是 RNN/LSTM/GRU。它们有两个根本缺陷:
- 顺序依赖:第
步必须等第 步算完,无法并行 - 长程依赖弱:尽管 LSTM 的门控机制改善了梯度传播,但实测中超过 100 步就难以保持信息
Vaswani et al. (2017) "Attention Is All You Need" 提出 Transformer,用纯注意力机制替代 RNN 的递归结构,做到:
- 完全并行(每个位置同时计算)
- 任意位置之间一步直达
- 复杂度
(vs RNN 的 ,长序列时 RNN 更优,但短序列 Transformer 远占便宜)
现代 LLM(GPT、LLaMA、Mistral 等)都是decoder-only变体:去掉了原始 Transformer 中的 encoder-decoder 交叉注意力,只保留 decoder 的因果自注意力 + FFN。
3.2 自注意力的数学推导
3.2.1 输入表示
输入序列
3.2.2 Q、K、V 投影
通过三个可学习的投影矩阵
得到 Query, Key, Value 矩阵,均为
直觉:
(第 行):当前 token 想"问"什么 :当前 token 提供什么"线索" :当前 token 的"内容"
3.2.3 Scaled Dot-Product Attention
注意力分数矩阵:
为什么除
均值为 0,方差为
接下来 row-wise softmax:
最后加权求和:
3.2.4 因果掩码(Causal Mask)
decoder-only LLM 是自回归的,第
由于
3.2.5 复杂度分析
| 操作 | 计算量 | 显存 |
|---|---|---|
| softmax | ||
| 输出投影 |
总计算量
总显存
这就是 FlashAttention(下一章)要解决的问题——避免显式存储
3.3 多头注意力 (MHA)
3.3.1 动机
单头注意力让模型用一个
3.3.2 公式
把
其中
实际实现中,把
def mha_forward(x, w_q, w_k, w_v, w_o, n_heads):
B, T, D = x.shape
d_h = D // n_heads
q = (x @ w_q).view(B, T, n_heads, d_h).transpose(1, 2) # [B, h, T, d_h]
k = (x @ w_k).view(B, T, n_heads, d_h).transpose(1, 2)
v = (x @ w_v).view(B, T, n_heads, d_h).transpose(1, 2)
scores = q @ k.transpose(-2, -1) / math.sqrt(d_h) # [B, h, T, T]
mask = torch.triu(torch.ones(T, T), diagonal=1).bool().to(x.device)
scores = scores.masked_fill(mask, float("-inf"))
attn = scores.softmax(-1)
out = attn @ v # [B, h, T, d_h]
out = out.transpose(1, 2).reshape(B, T, D)
return out @ w_o3.3.3 头数与头维度
经验配置:
| 模型 | |||
|---|---|---|---|
| GPT-2 small | 768 | 12 | 64 |
| GPT-2 medium | 1024 | 16 | 64 |
| GPT-3 175B | 12288 | 96 | 128 |
| LLaMA-2 7B | 4096 | 32 | 128 |
| LLaMA-2 70B | 8192 | 64 | 128 |
| LLaMA-3 8B | 4096 | 32 | 128 |
| LLaMA-3 405B | 16384 | 128 | 128 |
3.4 GQA 与 MQA:KV-Cache 压缩
3.4.1 推理瓶颈:KV-Cache 显存
自回归生成时,每次 decode 一个 token 需要重读所有历史 token 的 K、V。我们把它们缓存下来(KV-Cache),避免重复计算。但缓存本身占大量显存。
单层、batch=1、序列长度
(2 = K + V,2 B = FP16)
LLaMA-2 70B 推理 32K 上下文示例:
- 单 token KV:
MB - 32K 上下文 + batch 1:
GB
单卡 80GB H100 都装不下!
3.4.2 MQA (Multi-Query Attention)
Shazeer (2019) "Fast Transformer Decoding: One Write-Head is All You Need":
所有 Q 头共享同一组 K、V:
注意
- KV-Cache 缩小
倍 - 推理时 K、V 投影只算一次
- 训练时 K、V 共享,loss 略有下降
PaLM、Falcon 用 MQA。但 70B+ 模型上 MQA 的 quality 下降比较明显。
3.4.3 GQA (Grouped Query Attention)
Ainslie et al. (2023) "GQA: Training Generalized Multi-Query Transformer Models":折中方案。
设 KV 头数
| 极端情况 | 等价 |
|---|---|
| MHA | |
| MQA | |
| 中间 | GQA |
LLaMA-2 70B 用
3.4.4 KV-Cache 显存对比
LLaMA 类 70B 模型,32K 上下文,batch=1:
| 注意力变体 | KV-Cache | 推理速度 | |
|---|---|---|---|
| MHA | 64 | 80 GB | 1.0x |
| GQA-8 | 8 | 10 GB | 2.5x |
| MQA | 1 | 1.25 GB | 3.0x |
GQA-8 几乎不损失质量,是当前 7B-70B 模型的事实标准。
3.4.5 GQA 实现
def gqa_forward(x, w_q, w_kv, w_o, n_heads, n_kv_heads):
B, T, D = x.shape
d_h = D // n_heads
g = n_heads // n_kv_heads
q = (x @ w_q).view(B, T, n_heads, d_h).transpose(1, 2)
kv = (x @ w_kv).view(B, T, 2 * n_kv_heads, d_h).transpose(1, 2)
k, v = kv.chunk(2, dim=1) # [B, h_kv, T, d_h]
# 复制 K、V 到 h 头
k = k.repeat_interleave(g, dim=1) # [B, h, T, d_h]
v = v.repeat_interleave(g, dim=1)
scores = q @ k.transpose(-2, -1) / math.sqrt(d_h)
# ... mask + softmax + 加权求和(同 MHA)注:repeat_interleave 在 SDPA 内部由 FlashAttention 直接处理(不实际复制内存)。
3.5 RoPE 位置编码完整推导
3.5.1 为什么需要位置编码
自注意力本身对位置无感:交换输入序列的两个 token,输出顺序也跟着交换,但每个 token 看到的"周围信息"不变。这显然对语言不合适——猫追狗 和 狗追猫 含义完全相反。
需要把位置信息注入到 Q 和 K。
3.5.2 早期方案
绝对位置编码 (sinusoidal)
原 Transformer:
加到 embedding:
缺点:
- 位置和内容耦合,长度外推差
- 注意力计算无法直接利用相对位置
相对位置编码 (T5)
在注意力分数上加偏置
外推稍好但额外学习参数。
3.5.3 RoPE 的目标
Su et al. (2021) "RoFormer" 提出旋转位置编码 (Rotary Position Embedding, RoPE)。目标:构造函数
即注意力分数仅依赖相对位置
3.5.4 二维情形的优雅解
设
定义:
其中
确实只依赖
3.5.5 旋转矩阵形式
复数乘法
注意:
是正交矩阵,保模长: (旋转的相对性)
注意力分数:
只依赖
3.5.6 扩展到 d 维
把
频率从高到低(短波长到长波长)。整体旋转矩阵:
这是块对角矩阵,每个
最终:
注意力分数:
仅依赖相对位置
3.5.7 高效实现
显式构造
其中:
(每个频率重复 2 次) rotate_half(x) = [-x[d/2:], x[:d/2]]
这是两个对称的实数向量乘法,可以高效融合到矩阵乘的输入侧。
LLaMA 风格实现:
def precompute_freqs_cis(dim, max_seq_len, theta=10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len)
freqs = torch.outer(t, freqs) # [seq_len, dim/2]
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def apply_rotary_emb(xq, xk, freqs_cis):
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
return xq_out.type_as(xq), xk_out.type_as(xk)torch.view_as_complex 把每两个实数看成一个复数,乘以预计算的旋转因子,再 view 回实数。
3.5.8 RoPE 的优秀性质
- 相对位置:天然支持长程依赖
- 保模长:不改变 Q、K 的范数,与 RMSNorm 兼容
- 线性外推(弱):训练
的模型可以勉强用到 8K,但效果会下降 - 频率分布:高频维度捕捉局部信息,低频维度捕捉远程依赖
- 可解释性:旋转角度直接对应位置距离
3.5.9 长上下文扩展
直接外推到训练长度之外,注意力分数会发散(高频维度旋转过快)。常见扩展方案:
位置插值 PI (Chen et al. 2023)
把所有频率统一缩小:
效果:把"位置
NTK-aware Scaling (bloc97 2023)
根据维度自适应缩放,保留高频细节:
不需要微调即可外推 4-8 倍。
YaRN (Peng et al. 2023)
更细致地分频段处理:
- 高频维度:保持原频率
- 中频:按 NTK 缩放
- 低频:按 PI 缩放(线性插值)
加上 attention scale
LLaMA-3 的做法
LLaMA-3 把 RoPE base 从 10,000 提升到 500,000,相当于把所有频率降低 50 倍,配合后训练长度扩展到 128K,效果优秀。
DeepSeek-V3 上下文 128K 也用类似策略 + YaRN。
3.6 RMSNorm vs LayerNorm
3.6.1 LayerNorm
Ba et al. (2016):
参数:
3.6.2 RMSNorm
Zhang & Sennrich (2019):
参数:只有
3.6.3 对比
| 性质 | LayerNorm | RMSNorm |
|---|---|---|
| 减均值 | 是 | 否 |
| 平移参数 | 有 | 无 |
| 遍历次数 | 2 | 1 |
| 参数量 | ||
| FLOPs | ||
| 实测效果 | baseline | 持平或略好 |
为什么 RMSNorm 不减均值也能 work?
- 神经网络中,
通常已经接近 0(残差连接 + Pre-Norm 的统计性质) - 即使有
漂移,权重 的列空间会自动 absorb 这个偏移 - 减均值带来的 covariate shift 校准,对 LLM 收益不大
LLaMA、Mistral、Qwen、DeepSeek、Gemma、Mixtral 等几乎所有现代 LLM 都用 RMSNorm。
PyTorch 实现:
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
return self.weight * normtorch.rsqrt(x) = 1/sqrt(x) 是单条 GPU 指令,比 1 / torch.sqrt(x) 略快。
3.6.4 Pre-Norm vs Post-Norm
原 Transformer 用 Post-Norm:
Pre-Norm(现代主流):
差异:
- Post-Norm:残差路径上的信号经过非线性,深层难以训练,需要 warmup 才稳定
- Pre-Norm:残差路径是恒等流,梯度直接往前传,深层稳定但实测最终性能略差
实际配置:
- LLaMA、GPT-3、Mistral 用 Pre-Norm
- 但有 warmup 后 Post-Norm 仍可训
- DeepNet (2022) 提出
-Post-Norm 训了 1000 层
3.6.5 Sandwich Norm 与 Post-Pre-Norm
- Sandwich Norm (Ding et al. 2021):在 Sublayer 之前和之后都加 LN,即
- Pre/Post 混合:DeepSeek-V3 在 attention 用 Pre-Norm,在 FFN 用 Post-Norm,提升训练稳定性
3.7 SwiGLU 激活
3.7.1 标准 FFN
原 Transformer 的 FFN:
隐层维度
GPT 系列用 GELU 替代 ReLU:
GELU 在小负值处保留少量信号,比 ReLU 略好。
3.7.2 GLU 家族
Dauphin et al. (2016) 提出 GLU (Gated Linear Unit):
其中
变体(用不同激活替代 sigmoid):
| 变体 | 激活 | 公式 |
|---|---|---|
| ReGLU | ReLU | |
| GeGLU | GELU | |
| SwiGLU | Swish |
其中 Swish (a.k.a. SiLU):
LLaMA 等取
3.7.3 SwiGLU FFN
完整公式:
含三个矩阵:
(gate): (up): (down):
为保持参数量与原 FFN(含 2 矩阵)持平,把
LLaMA-2 7B:
3.7.4 为何 SwiGLU 更好
Shazeer (2020) "GLU Variants Improve Transformer" 在 T5 上做了广泛对比:
| FFN | Pile PPL | Glue Avg |
|---|---|---|
| ReLU FFN | 1.997 | 83.80 |
| GELU FFN | 1.983 | 84.20 |
| SwiGLU FFN | 1.944 | 84.36 |
| GeGLU FFN | 1.942 | 84.12 |
SwiGLU/GeGLU 在 perplexity 和下游任务上稳定优于 ReLU/GELU。原因:
- 门控引入了乘法非线性(标准 FFN 只有加法 + 单点非线性)
- 信息流更灵活,每个隐层维度可以有"开关"
Shazeer 自己评论:"we offer no explanation as to why these architectures seem to work; we attribute their success, as all else, to divine benevolence."
3.7.5 PyTorch 实现
class FFN(nn.Module):
def __init__(self, dim, hidden_dim, multiple_of=256):
super().__init__()
# 取 multiple_of 倍数
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False) # gate
self.w3 = nn.Linear(dim, hidden_dim, bias=False) # up
self.w2 = nn.Linear(hidden_dim, dim, bias=False) # down
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))F.silu(x) = x * F.sigmoid(x),PyTorch 内置。
3.7.6 FFN 计算量与显存
每 token 一次 FFN:
短上下文时 FFN 是计算量大头(约
3.8 KV-Cache 机制详解
3.8.1 自回归生成的两阶段
Prefill 阶段:处理 prompt,一次性算出所有位置的 K、V,缓存。
Decode 阶段:每生成一个 token:
- 用上一步的 token 算新的
- 把
append 到 cache - 用
与整个 cache 做注意力 - 输出下一 token
# 简化伪代码
class KVCache:
def __init__(self, max_len, n_kv_heads, d_h, dtype, device):
self.k = torch.zeros(1, n_kv_heads, max_len, d_h, dtype=dtype, device=device)
self.v = torch.zeros(1, n_kv_heads, max_len, d_h, dtype=dtype, device=device)
self.cur = 0
def update(self, k_new, v_new):
L = k_new.size(2)
self.k[:, :, self.cur:self.cur+L] = k_new
self.v[:, :, self.cur:self.cur+L] = v_new
self.cur += L
return self.k[:, :, :self.cur], self.v[:, :, :self.cur]3.8.2 显存与访存分析
显存:
访存瓶颈:每生成 1 token 需要从 HBM 读取整个 KV-Cache。
- KV-Cache 大小
GB - A100 HBM 带宽 1.5 TB/s
- 单 token decode 至少
ms(仅 KV 读取)
实际更慢(还有计算、参数读取等),decode 阶段基本是 memory-bound。
这就是为什么 GQA 大幅加速推理:KV 小 8 倍,访存也少 8 倍。
3.8.3 KV 量化与压缩
| 方法 | 节省 | 代价 |
|---|---|---|
| FP8 KV | 2x | 几乎无损 |
| INT8 KV | 2x | 微小精度下降 |
| INT4 KV | 4x | 1-2% PPL 上升 |
| KIVI 2-bit | 8x | 长上下文略损 |
| MLA (DeepSeek) | 4x+ | 几乎无损(架构级) |
MLA (Multi-head Latent Attention, DeepSeek-V2)
不是事后量化,而是从架构设计上压缩。把 K、V 投影到一个低维 latent 向量
推理时再"解压":
DeepSeek-V2:
DeepSeek-V3 同样使用 MLA,KV-Cache 比 LLaMA-3 70B 小 5-7 倍。
3.8.4 PagedAttention (vLLM)
KV-Cache 在不同 batch 间长度差异大,连续分配会浪费。
PagedAttention (Kwon et al. 2023) 把 KV-Cache 分页 (block),每页 16 token,按需分配,类似 OS 虚拟内存。
效果:
- 显存利用率从 60-80% 提升到 96%+
- 支持更高的并发 batch
- vLLM 的核心优化
3.8.5 KV-Cache 复用
一些场景 KV-Cache 可以跨请求复用:
- 相同的 system prompt:所有用户共用 prefix KV
- Beam search:多个候选共享 prefix
- Speculative decoding:草稿模型与目标模型共享 KV
vLLM、TensorRT-LLM 都支持 prefix caching。
3.9 完整 Transformer Block
把所有组件拼起来,一个现代 LLaMA 风格的 Transformer block:
class LlamaBlock(nn.Module):
def __init__(self, dim, n_heads, n_kv_heads, ffn_hidden, eps=1e-6):
super().__init__()
self.attn_norm = RMSNorm(dim, eps)
self.attn = GQAAttention(dim, n_heads, n_kv_heads)
self.ffn_norm = RMSNorm(dim, eps)
self.ffn = FFN(dim, ffn_hidden)
def forward(self, x, freqs_cis, kv_cache=None):
h = x + self.attn(self.attn_norm(x), freqs_cis, kv_cache)
out = h + self.ffn(self.ffn_norm(h))
return out整个模型:
class Llama(nn.Module):
def __init__(self, vocab_size, dim, n_layers, n_heads, n_kv_heads, ffn_hidden, max_seq_len):
super().__init__()
self.tok_emb = nn.Embedding(vocab_size, dim)
self.layers = nn.ModuleList([
LlamaBlock(dim, n_heads, n_kv_heads, ffn_hidden)
for _ in range(n_layers)
])
self.norm = RMSNorm(dim)
self.lm_head = nn.Linear(dim, vocab_size, bias=False)
self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_seq_len)
def forward(self, tokens, kv_caches=None):
x = self.tok_emb(tokens)
freqs_cis = self.freqs_cis[:tokens.size(1)]
for i, layer in enumerate(self.layers):
cache = kv_caches[i] if kv_caches else None
x = layer(x, freqs_cis, cache)
x = self.norm(x)
return self.lm_head(x)LLaMA-2 7B 配置:
(MHA,因为 7B 不大)
LLaMA-2 70B:
参数量公式(忽略 bias 和 norm 的小项):
其中
3.10 现代变体一瞥
3.10.1 Sliding Window Attention (Mistral)
只看最近
- 显存
而非 - 长序列每层只见
个邻居,但多层叠加后感受野 (类似 CNN) - Mistral-7B:
,32 层 → 理论感受野 130K
3.10.2 Mixture of Experts (MoE)
把 FFN 替换为 MoE 层(第 5 章详谈)。
3.10.3 Mamba / SSM
State Space Model 架构,替代注意力,复杂度
3.10.4 Parallel Block (Falcon, GPT-J)
把 attention 和 FFN 并行:
少一次 LN,反向传播更并行。
3.11 本章小结
本章拆解了现代 LLM 的核心构件:
- 自注意力:Q、K、V 投影 + scaled dot-product softmax + 加权求和;除
防止饱和;因果 mask 实现自回归。 - MHA → GQA → MQA:Q 头不变、KV 头减少,KV-Cache 压缩 8 倍以上。LLaMA-3 全系列用 GQA-8。
- RoPE:通过旋转矩阵注入相对位置,频率
;可外推(PI / NTK / YaRN),LLaMA-3 用 base=500K 支持 128K。 - RMSNorm:去掉 LayerNorm 的均值减法,少一倍参数,速度更快,效果持平。
- SwiGLU:FFN 加门控,
,质量稳定优于 ReLU/GELU。 - KV-Cache:自回归推理的访存瓶颈,GQA / MLA / 量化是优化方向。
下一章我们讨论如何让标准注意力本身在 GPU 上跑得更快——FlashAttention。
3.12 思考题
GQA 退化分析:当
时 GQA 等价于 MHA,当 时等价于 MQA。请定量分析 KV-Cache 显存、推理 FLOPs、训练 quality 三者随 变化的曲线(用 LLaMA-2 70B 的尺度),说明为什么 是甜蜜点。 RoPE 外推数学:标准 RoPE 训练长度
,base = 10000。若不做任何改动直接推到 ,最高频维度( )的旋转角度变到多少?为什么会"周期回卷"导致注意力失效?请用三角函数证明 NTK-aware Scaling 通过 能保留高频精度。 SwiGLU 参数量推导:标准 ReLU FFN
,参数量 。SwiGLU FFN 含 3 个矩阵,要保持总参数量 ,应取 。LLaMA-2 7B 实际取 11008,相对于 是 的多少倍?为何会取这个值?