Skip to content

3. Transformer 架构详解

导读:现代 LLM 都是 decoder-only Transformer 的变体。本章从自注意力机制的数学推导出发,逐步剖析 MHA / GQA / MQA、RoPE 位置编码、RMSNorm、SwiGLU、KV-Cache 等核心组件,给出每个设计的动机、推导与现代实现。

3.1 从 RNN 到 Transformer

在 2017 年之前,序列建模的主流是 RNN/LSTM/GRU。它们有两个根本缺陷:

  1. 顺序依赖:第 t 步必须等第 t1 步算完,无法并行
  2. 长程依赖弱:尽管 LSTM 的门控机制改善了梯度传播,但实测中超过 100 步就难以保持信息

Vaswani et al. (2017) "Attention Is All You Need" 提出 Transformer,用纯注意力机制替代 RNN 的递归结构,做到:

  • 完全并行(每个位置同时计算)
  • 任意位置之间一步直达
  • 复杂度 O(N2d)(vs RNN 的 O(Nd2),长序列时 RNN 更优,但短序列 Transformer 远占便宜)

现代 LLM(GPT、LLaMA、Mistral 等)都是decoder-only变体:去掉了原始 Transformer 中的 encoder-decoder 交叉注意力,只保留 decoder 的因果自注意力 + FFN。


3.2 自注意力的数学推导

3.2.1 输入表示

输入序列 x1,x2,,xT,每个 xt 经过 embedding 得到向量 xtRd。整个序列拼成矩阵:

X=(x1x2xT)RT×d

3.2.2 Q、K、V 投影

通过三个可学习的投影矩阵 WQ,WK,WVRd×dk

Q=XWQ,K=XWK,V=XWV

得到 Query, Key, Value 矩阵,均为 RT×dk

直觉

  • qt(第 t 行):当前 token 想"问"什么
  • kt:当前 token 提供什么"线索"
  • vt:当前 token 的"内容"

3.2.3 Scaled Dot-Product Attention

注意力分数矩阵:

S=QKdkRT×T

Sij 表示位置 i 对位置 j 的关注度(未归一化)。

为什么除 dk?设 q,k 各分量独立同分布于 N(0,1),则点积

qk=i=1dkqiki

均值为 0,方差为 dk。当 dk 大时(如 128),点积绝对值动辄达到 12811,softmax 会进入饱和区,梯度近乎为 0。除以 dk 把方差归一回 1,让 softmax 工作在敏感区。

接下来 row-wise softmax:

A=softmax(S)=softmax(QKdk)

Aij 是位置 i 对位置 j 的注意力权重,每行和为 1。

最后加权求和:

Attn(Q,K,V)=AV=softmax(QKdk)VRT×dk

3.2.4 因果掩码(Causal Mask)

decoder-only LLM 是自回归的,第 t 个 token 只能看到 t 的 token。在 softmax 之前加上掩码:

Mij={0jij>iA=softmax(QKdk+M)

由于 exp()=0,被 mask 的位置权重为 0。

3.2.5 复杂度分析

操作计算量显存
Q,K,V 投影3Td23Td
S=QKT2dkT2
softmaxT2T2
AVT2dkTdk
输出投影Td2Td

总计算量 Θ(T2d+Td2);当 T>d 时主导项是 T2d长上下文是 quadratic 瓶颈)。

总显存 Θ(T2)(注意力矩阵 A);当 T=32768,FP16 下单层就要 3276822=2 GB。

这就是 FlashAttention(下一章)要解决的问题——避免显式存储 T×T 的注意力矩阵


3.3 多头注意力 (MHA)

3.3.1 动机

单头注意力让模型用一个 dk 维子空间表达"关注什么"。但语言中的依赖关系是多种类的:句法依赖、语义关联、共指、远程指代——每种关系最好有自己的子空间。

3.3.2 公式

d 维空间分成 h 个头,每头维度 dh=d/h

headi=Attn(XWQ(i),XWK(i),XWV(i))MHA(X)=Concat(head1,,headh)WO

其中 WQ(i),WK(i),WV(i)Rd×dhWORd×d

实际实现中,把 h 个头的投影矩阵拼起来:WQ=[WQ(1),,WQ(h)]Rd×d,一次矩阵乘搞定,再 reshape 出头维度:

python
def mha_forward(x, w_q, w_k, w_v, w_o, n_heads):
    B, T, D = x.shape
    d_h = D // n_heads
    q = (x @ w_q).view(B, T, n_heads, d_h).transpose(1, 2)  # [B, h, T, d_h]
    k = (x @ w_k).view(B, T, n_heads, d_h).transpose(1, 2)
    v = (x @ w_v).view(B, T, n_heads, d_h).transpose(1, 2)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_h)        # [B, h, T, T]
    mask = torch.triu(torch.ones(T, T), diagonal=1).bool().to(x.device)
    scores = scores.masked_fill(mask, float("-inf"))
    attn = scores.softmax(-1)
    out = attn @ v                                            # [B, h, T, d_h]
    out = out.transpose(1, 2).reshape(B, T, D)
    return out @ w_o

3.3.3 头数与头维度

经验配置:

模型dhdh
GPT-2 small7681264
GPT-2 medium10241664
GPT-3 175B1228896128
LLaMA-2 7B409632128
LLaMA-2 70B819264128
LLaMA-3 8B409632128
LLaMA-3 405B16384128128

dh 几乎总是 64 或 128,与 GPU tensor core 对齐。


3.4 GQA 与 MQA:KV-Cache 压缩

3.4.1 推理瓶颈:KV-Cache 显存

自回归生成时,每次 decode 一个 token 需要重读所有历史 token 的 K、V。我们把它们缓存下来(KV-Cache),避免重复计算。但缓存本身占大量显存。

单层、batch=1、序列长度 sh 个头、头维度 dh、FP16 的 KV-Cache 大小:

MKV(layer)=2shdh2 B

(2 = K + V,2 B = FP16)

LLaMA-2 70B 推理 32K 上下文示例:

  • L=80,h=64,dh=128
  • 单 token KV:280641282=2.6 MB
  • 32K 上下文 + batch 1:80 GB

单卡 80GB H100 都装不下!

3.4.2 MQA (Multi-Query Attention)

Shazeer (2019) "Fast Transformer Decoding: One Write-Head is All You Need":

所有 Q 头共享同一组 K、V

headi=Attn(XWQ(i),XWK,XWV)

注意 WK,WV 没有上标 (i)。这样:

  • KV-Cache 缩小 h
  • 推理时 K、V 投影只算一次
  • 训练时 K、V 共享,loss 略有下降

PaLM、Falcon 用 MQA。但 70B+ 模型上 MQA 的 quality 下降比较明显。

3.4.3 GQA (Grouped Query Attention)

Ainslie et al. (2023) "GQA: Training Generalized Multi-Query Transformer Models":折中方案。

设 KV 头数 hkv,满足 1hkvhhkv|h。每 h/hkv 个 Q 头共享一组 K、V:

headi=Attn(XWQ(i),XWK(i/g),XWV(i/g)),g=h/hkv
极端情况等价
hkv=hMHA
hkv=1MQA
中间GQA

LLaMA-2 70B 用 h=64,hkv=8g=8),Mistral 7B 用 h=32,hkv=8,LLaMA-3 全系列用 hkv=8

3.4.4 KV-Cache 显存对比

LLaMA 类 70B 模型,32K 上下文,batch=1:

注意力变体hkvKV-Cache推理速度
MHA6480 GB1.0x
GQA-8810 GB2.5x
MQA11.25 GB3.0x

GQA-8 几乎不损失质量,是当前 7B-70B 模型的事实标准。

3.4.5 GQA 实现

python
def gqa_forward(x, w_q, w_kv, w_o, n_heads, n_kv_heads):
    B, T, D = x.shape
    d_h = D // n_heads
    g = n_heads // n_kv_heads
    q = (x @ w_q).view(B, T, n_heads, d_h).transpose(1, 2)
    kv = (x @ w_kv).view(B, T, 2 * n_kv_heads, d_h).transpose(1, 2)
    k, v = kv.chunk(2, dim=1)               # [B, h_kv, T, d_h]
    # 复制 K、V 到 h 头
    k = k.repeat_interleave(g, dim=1)        # [B, h, T, d_h]
    v = v.repeat_interleave(g, dim=1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_h)
    # ... mask + softmax + 加权求和(同 MHA)

注:repeat_interleave 在 SDPA 内部由 FlashAttention 直接处理(不实际复制内存)。


3.5 RoPE 位置编码完整推导

3.5.1 为什么需要位置编码

自注意力本身对位置无感:交换输入序列的两个 token,输出顺序也跟着交换,但每个 token 看到的"周围信息"不变。这显然对语言不合适——猫追狗狗追猫 含义完全相反。

需要把位置信息注入到 Q 和 K。

3.5.2 早期方案

绝对位置编码 (sinusoidal)

原 Transformer:

PE(pos,2i)=sin(pos/100002i/d),PE(pos,2i+1)=cos(pos/100002i/d)

加到 embedding:xt=xt+PE(t)

缺点:

  • 位置和内容耦合,长度外推差
  • 注意力计算无法直接利用相对位置

相对位置编码 (T5)

在注意力分数上加偏置 bij

Sij=qikjdk+bij

外推稍好但额外学习参数。

3.5.3 RoPE 的目标

Su et al. (2021) "RoFormer" 提出旋转位置编码 (Rotary Position Embedding, RoPE)。目标:构造函数 f,使得

f(q,m),f(k,n)=g(q,k,mn)

即注意力分数仅依赖相对位置 mn,与绝对位置 m,n 无关。

3.5.4 二维情形的优雅解

d=2。把 q=(q1,q2)R2 视为复数 zq=q1+iq2k 同理。

定义:

f(q,m)=zqeimθ,f(k,n)=zkeinθ

其中 θ 是固定频率参数。复数内积的实部:

Re[f(q,m)f(k,n)]=Re[zqzkei(mn)θ]

确实只依赖 mn,目标达成。

3.5.5 旋转矩阵形式

复数乘法 zeiϕ 等价于 2D 旋转。把 f(q,m) 写成矩阵形式:

f(q,m)=R(mθ)q,R(ϕ)=(cosϕsinϕsinϕcosϕ)

注意:

  • R(ϕ) 是正交矩阵,保模长f(q,m)=q
  • R(mθ)R(nθ)=R((mn)θ)(旋转的相对性)

注意力分数:

qmR(mθ)R(nθ)kn=qmR((nm)θ)kn

只依赖 nm

3.5.6 扩展到 d 维

d 维向量切成 d/2 个 2D 分量,每对用不同频率旋转:

θi=100002(i1)/d,i=1,2,,d/2

频率从高到低(短波长到长波长)。整体旋转矩阵:

RΘ,md=diag(R(mθ1),R(mθ2),,R(mθd/2))Rd×d

这是块对角矩阵,每个 2×2 块独立旋转。

最终:

qmRoPE=RΘ,mdWQxm,knRoPE=RΘ,ndWKxn

注意力分数:

(qmRoPE)knRoPE=(WQxm)RΘ,nmd(WKxn)

仅依赖相对位置 nm,目标达成。

3.5.7 高效实现

显式构造 d×d 稀疏矩阵开销大。实践中用以下等价形式:

RoPE(x,m)=xcos(mΘ)+rotate_half(x)sin(mΘ)

其中:

  • Θ=(θ1,θ1,θ2,θ2,,θd/2,θd/2)Rd(每个频率重复 2 次)
  • rotate_half(x) = [-x[d/2:], x[:d/2]]

这是两个对称的实数向量乘法,可以高效融合到矩阵乘的输入侧。

LLaMA 风格实现:

python
def precompute_freqs_cis(dim, max_seq_len, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, freqs)             # [seq_len, dim/2]
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def apply_rotary_emb(xq, xk, freqs_cis):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

torch.view_as_complex 把每两个实数看成一个复数,乘以预计算的旋转因子,再 view 回实数。

3.5.8 RoPE 的优秀性质

  1. 相对位置:天然支持长程依赖
  2. 保模长:不改变 Q、K 的范数,与 RMSNorm 兼容
  3. 线性外推(弱):训练 Tmax=4K 的模型可以勉强用到 8K,但效果会下降
  4. 频率分布:高频维度捕捉局部信息,低频维度捕捉远程依赖
  5. 可解释性:旋转角度直接对应位置距离

3.5.9 长上下文扩展

直接外推到训练长度之外,注意力分数会发散(高频维度旋转过快)。常见扩展方案:

位置插值 PI (Chen et al. 2023)

把所有频率统一缩小:

θi=θis,s=TnewTold

效果:把"位置 m"看成"位置 m/s"。需要少量微调。

NTK-aware Scaling (bloc97 2023)

根据维度自适应缩放,保留高频细节:

base=basesd/(d2)

不需要微调即可外推 4-8 倍。

YaRN (Peng et al. 2023)

更细致地分频段处理:

  • 高频维度:保持原频率
  • 中频:按 NTK 缩放
  • 低频:按 PI 缩放(线性插值)

加上 attention scale 1/t(其中 t 是温度),缓解长序列分布漂移。

LLaMA-3 的做法

LLaMA-3 把 RoPE base 从 10,000 提升到 500,000,相当于把所有频率降低 50 倍,配合后训练长度扩展到 128K,效果优秀。

DeepSeek-V3 上下文 128K 也用类似策略 + YaRN。


3.6 RMSNorm vs LayerNorm

3.6.1 LayerNorm

Ba et al. (2016):

μ=1di=1dxi,σ2=1di=1d(xiμ)2LN(x)=γxμσ2+ϵ+β

参数:γ,βRd。需要两次遍历向量(计算 μσ)。

3.6.2 RMSNorm

Zhang & Sennrich (2019):

RMS(x)=1di=1dxi2+ϵRMSNorm(x)=γxRMS(x)

参数:只有 γ,无 β(不平移)。一次遍历(求平方和)。

3.6.3 对比

性质LayerNormRMSNorm
减均值
平移参数 β
遍历次数21
参数量2dd
FLOPs5d3d
实测效果baseline持平或略好

为什么 RMSNorm 不减均值也能 work

  • 神经网络中,μ 通常已经接近 0(残差连接 + Pre-Norm 的统计性质)
  • 即使有 μ 漂移,权重 W 的列空间会自动 absorb 这个偏移
  • 减均值带来的 covariate shift 校准,对 LLM 收益不大

LLaMA、Mistral、Qwen、DeepSeek、Gemma、Mixtral 等几乎所有现代 LLM 都用 RMSNorm。

PyTorch 实现:

python
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * norm

torch.rsqrt(x) = 1/sqrt(x) 是单条 GPU 指令,比 1 / torch.sqrt(x) 略快。

3.6.4 Pre-Norm vs Post-Norm

原 Transformer 用 Post-Norm

xl+1=LN(xl+Sublayer(xl))

Pre-Norm(现代主流):

xl+1=xl+Sublayer(LN(xl))

差异:

  • Post-Norm:残差路径上的信号经过非线性,深层难以训练,需要 warmup 才稳定
  • Pre-Norm:残差路径是恒等流,梯度直接往前传,深层稳定但实测最终性能略差

实际配置:

  • LLaMA、GPT-3、Mistral 用 Pre-Norm
  • 但有 warmup 后 Post-Norm 仍可训
  • DeepNet (2022) 提出 α-Post-Norm 训了 1000 层

3.6.5 Sandwich Norm 与 Post-Pre-Norm

  • Sandwich Norm (Ding et al. 2021):在 Sublayer 之前和之后都加 LN,即 x+LN(Sublayer(LN(x)))
  • Pre/Post 混合:DeepSeek-V3 在 attention 用 Pre-Norm,在 FFN 用 Post-Norm,提升训练稳定性

3.7 SwiGLU 激活

3.7.1 标准 FFN

原 Transformer 的 FFN:

FFN(x)=ReLU(xW1+b1)W2+b2

隐层维度 dff=4d(经验最佳)。

GPT 系列用 GELU 替代 ReLU:

GELU(x)=xΦ(x)0.5x(1+tanh[2/π(x+0.044715x3)])

GELU 在小负值处保留少量信号,比 ReLU 略好。

3.7.2 GLU 家族

Dauphin et al. (2016) 提出 GLU (Gated Linear Unit)

GLU(x)=σ(xW)(xV)

其中 σ 是 sigmoid。"门控":用 σ(xW) 决定 xV 的每个分量保留多少。

变体(用不同激活替代 sigmoid):

变体激活公式
ReGLUReLUmax(0,xW)(xV)
GeGLUGELUGELU(xW)(xV)
SwiGLUSwishSwishβ(xW)(xV)

其中 Swish (a.k.a. SiLU):

Swishβ(x)=xσ(βx),SiLU(x)=xσ(x)

LLaMA 等取 β=1,即 SiLU。

3.7.3 SwiGLU FFN

完整公式:

FFNSwiGLU(x)=(SiLU(xW1)(xW3))W2

含三个矩阵:

  • W1(gate):Rd×dff
  • W3(up):Rd×dff
  • W2(down):Rdff×d

为保持参数量与原 FFN(含 2 矩阵)持平,把 dff4d 降到 8d3,再向上取到 64 或 256 的倍数。

LLaMA-2 7B:d=40968×40963=10922.67,取 64 倍数得 dff=11008

3.7.4 为何 SwiGLU 更好

Shazeer (2020) "GLU Variants Improve Transformer" 在 T5 上做了广泛对比:

FFNPile PPLGlue Avg
ReLU FFN1.99783.80
GELU FFN1.98384.20
SwiGLU FFN1.94484.36
GeGLU FFN1.94284.12

SwiGLU/GeGLU 在 perplexity 和下游任务上稳定优于 ReLU/GELU。原因:

  • 门控引入了乘法非线性(标准 FFN 只有加法 + 单点非线性)
  • 信息流更灵活,每个隐层维度可以有"开关"

Shazeer 自己评论:"we offer no explanation as to why these architectures seem to work; we attribute their success, as all else, to divine benevolence."

3.7.5 PyTorch 实现

python
class FFN(nn.Module):
    def __init__(self, dim, hidden_dim, multiple_of=256):
        super().__init__()
        # 取 multiple_of 倍数
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

F.silu(x) = x * F.sigmoid(x),PyTorch 内置。

3.7.6 FFN 计算量与显存

每 token 一次 FFN:

FLOPsFFN=2ddff36ddff=16d2FLOPsAttn4d2T+2T2d

短上下文时 FFN 是计算量大头(约 16d24d2+16d280%)。


3.8 KV-Cache 机制详解

3.8.1 自回归生成的两阶段

Prefill 阶段:处理 prompt,一次性算出所有位置的 K、V,缓存。

Decode 阶段:每生成一个 token:

  1. 用上一步的 token 算新的 q,k,v
  2. k,v append 到 cache
  3. q 与整个 cache 做注意力
  4. 输出下一 token
python
# 简化伪代码
class KVCache:
    def __init__(self, max_len, n_kv_heads, d_h, dtype, device):
        self.k = torch.zeros(1, n_kv_heads, max_len, d_h, dtype=dtype, device=device)
        self.v = torch.zeros(1, n_kv_heads, max_len, d_h, dtype=dtype, device=device)
        self.cur = 0

    def update(self, k_new, v_new):
        L = k_new.size(2)
        self.k[:, :, self.cur:self.cur+L] = k_new
        self.v[:, :, self.cur:self.cur+L] = v_new
        self.cur += L
        return self.k[:, :, :self.cur], self.v[:, :, :self.cur]

3.8.2 显存与访存分析

显存L2shkvdhbytes。LLaMA-2 70B (GQA-8) + 32K:10 GB。

访存瓶颈:每生成 1 token 需要从 HBM 读取整个 KV-Cache。

  • KV-Cache 大小 10 GB
  • A100 HBM 带宽 1.5 TB/s
  • 单 token decode 至少 10/1500=6.7 ms(仅 KV 读取)

实际更慢(还有计算、参数读取等),decode 阶段基本是 memory-bound

这就是为什么 GQA 大幅加速推理:KV 小 8 倍,访存也少 8 倍。

3.8.3 KV 量化与压缩

方法节省代价
FP8 KV2x几乎无损
INT8 KV2x微小精度下降
INT4 KV4x1-2% PPL 上升
KIVI 2-bit8x长上下文略损
MLA (DeepSeek)4x+几乎无损(架构级)

MLA (Multi-head Latent Attention, DeepSeek-V2)

不是事后量化,而是从架构设计上压缩。把 K、V 投影到一个低维 latent 向量 c,缓存 c 而非 K、V:

ct=WDKVxtRdc,dchdh

推理时再"解压":

kt=WUKct,vt=WUVct

DeepSeek-V2: d=5120, hdh=16384(标准 MHA),但 dc=512,KV-Cache 减小 32 倍。配合 RoPE 兼容设计,质量优于 GQA。

DeepSeek-V3 同样使用 MLA,KV-Cache 比 LLaMA-3 70B 小 5-7 倍。

3.8.4 PagedAttention (vLLM)

KV-Cache 在不同 batch 间长度差异大,连续分配会浪费。

PagedAttention (Kwon et al. 2023) 把 KV-Cache 分 (block),每页 16 token,按需分配,类似 OS 虚拟内存。

效果:

  • 显存利用率从 60-80% 提升到 96%+
  • 支持更高的并发 batch
  • vLLM 的核心优化

3.8.5 KV-Cache 复用

一些场景 KV-Cache 可以跨请求复用

  • 相同的 system prompt:所有用户共用 prefix KV
  • Beam search:多个候选共享 prefix
  • Speculative decoding:草稿模型与目标模型共享 KV

vLLM、TensorRT-LLM 都支持 prefix caching。


3.9 完整 Transformer Block

把所有组件拼起来,一个现代 LLaMA 风格的 Transformer block:

python
class LlamaBlock(nn.Module):
    def __init__(self, dim, n_heads, n_kv_heads, ffn_hidden, eps=1e-6):
        super().__init__()
        self.attn_norm = RMSNorm(dim, eps)
        self.attn = GQAAttention(dim, n_heads, n_kv_heads)
        self.ffn_norm = RMSNorm(dim, eps)
        self.ffn = FFN(dim, ffn_hidden)

    def forward(self, x, freqs_cis, kv_cache=None):
        h = x + self.attn(self.attn_norm(x), freqs_cis, kv_cache)
        out = h + self.ffn(self.ffn_norm(h))
        return out

整个模型:

python
class Llama(nn.Module):
    def __init__(self, vocab_size, dim, n_layers, n_heads, n_kv_heads, ffn_hidden, max_seq_len):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList([
            LlamaBlock(dim, n_heads, n_kv_heads, ffn_hidden)
            for _ in range(n_layers)
        ])
        self.norm = RMSNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_seq_len)

    def forward(self, tokens, kv_caches=None):
        x = self.tok_emb(tokens)
        freqs_cis = self.freqs_cis[:tokens.size(1)]
        for i, layer in enumerate(self.layers):
            cache = kv_caches[i] if kv_caches else None
            x = layer(x, freqs_cis, cache)
        x = self.norm(x)
        return self.lm_head(x)

LLaMA-2 7B 配置:

  • dim=4096
  • n_layers=32
  • n_heads=32
  • n_kv_heads=32(MHA,因为 7B 不大)
  • ffn_hidden=11008
  • vocab_size=32000

LLaMA-2 70B:dim=8192, n_layers=80, n_heads=64, n_kv_heads=8(GQA), ffn_hidden=28672

参数量公式(忽略 bias 和 norm 的小项):

NVd+L(4ddkv+3ddff)+dV

其中 dkv=d(1+2hkv/h)/2(Q 全维 + K, V 缩小)。


3.10 现代变体一瞥

3.10.1 Sliding Window Attention (Mistral)

只看最近 w 个 token:

Attnw(qi,kj)={normaliwjielse
  • 显存 O(Tw) 而非 O(T2)
  • 长序列每层只见 w 个邻居,但多层叠加后感受野 Lw(类似 CNN)
  • Mistral-7B:w=4096,32 层 → 理论感受野 130K

3.10.2 Mixture of Experts (MoE)

把 FFN 替换为 MoE 层(第 5 章详谈)。

3.10.3 Mamba / SSM

State Space Model 架构,替代注意力,复杂度 O(T)。但目前主流仍是 Transformer。

3.10.4 Parallel Block (Falcon, GPT-J)

把 attention 和 FFN 并行:

xl+1=xl+Attn(LN(xl))+FFN(LN(xl))

少一次 LN,反向传播更并行。


3.11 本章小结

本章拆解了现代 LLM 的核心构件:

  1. 自注意力:Q、K、V 投影 + scaled dot-product softmax + 加权求和;除 dk 防止饱和;因果 mask 实现自回归。
  2. MHA → GQA → MQA:Q 头不变、KV 头减少,KV-Cache 压缩 8 倍以上。LLaMA-3 全系列用 GQA-8。
  3. RoPE:通过旋转矩阵注入相对位置,频率 θi=100002(i1)/d;可外推(PI / NTK / YaRN),LLaMA-3 用 base=500K 支持 128K。
  4. RMSNorm:去掉 LayerNorm 的均值减法,少一倍参数,速度更快,效果持平。
  5. SwiGLU:FFN 加门控,dff=8d/3,质量稳定优于 ReLU/GELU。
  6. KV-Cache:自回归推理的访存瓶颈,GQA / MLA / 量化是优化方向。

下一章我们讨论如何让标准注意力本身在 GPU 上跑得更快——FlashAttention。


3.12 思考题

  1. GQA 退化分析:当 hkv=h 时 GQA 等价于 MHA,当 hkv=1 时等价于 MQA。请定量分析 KV-Cache 显存、推理 FLOPs、训练 quality 三者随 hkv 变化的曲线(用 LLaMA-2 70B 的尺度),说明为什么 hkv=8 是甜蜜点。

  2. RoPE 外推数学:标准 RoPE 训练长度 Tmax=4096,base = 10000。若不做任何改动直接推到 T=32768,最高频维度(θ1=1)的旋转角度变到多少?为什么会"周期回卷"导致注意力失效?请用三角函数证明 NTK-aware Scaling 通过 base=basesd/(d2) 能保留高频精度。

  3. SwiGLU 参数量推导:标准 ReLU FFN dff=4d,参数量 8d2。SwiGLU FFN 含 3 个矩阵,要保持总参数量 8d2,应取 dff?。LLaMA-2 7B 实际取 11008,相对于 d=40968d/3 的多少倍?为何会取这个值?

基于 MIT 协议发布