Skip to content

7. 训练工程

导读:算法、数据、并行架构都到位了,最后一公里是训练工程——精度选择、梯度累积、优化器、学习率调度、scaling law、稳定性管理。这些细节决定模型能否高质量地完成万亿 token 的训练。本章把这些工程实践系统化,给出公式、对比与实战建议。

7.1 混合精度训练

7.1.1 浮点格式回顾

格式总位指数位尾数位动态范围精度 (ULP)用途
FP323282310381038223默认 / 主权重
FP16165106×10565504210早期 mixed precision
BF161687同 FP3227现代默认
TF3219810同 FP32210A100 张量核
FP8 E4M3843±44823H100 前向
FP8 E5M2852±5734422H100 反向
FP4 (实验)421极窄极差推理量化

7.1.2 FP16 + Loss Scaling

Micikevicius et al. (2017) "Mixed Precision Training" 是混合精度的开山之作。

问题:FP16 表示范围小,反向传播中小梯度(如 107)会下溢为 0。

方案:把 loss 乘上一个标量 S(如 S=215=32768),梯度也按比例放大;优化器更新前再除回 S

python
S = 2**15
loss = forward(...)
scaled_loss = loss * S
scaled_loss.backward()           # 梯度放大 S 倍,避免下溢
for p in model.parameters():
    p.grad.data /= S              # 还原
optimizer.step()

主权重 (master weights) 保持 FP32:模型副本 FP16 用于前向/反向,FP32 主副本接收优化器更新(避免精度损失累积)。

7.1.3 动态 Loss Scaling

固定 S 不够鲁棒:早期 S 可能过大溢出,后期可能过小。动态 loss scaling

python
# 启动 S = 2^15
S = 2**15
growth_count = 0
for step:
    scaled_loss = loss * S
    scaled_loss.backward()
    if any(grad has NaN or Inf):
        S = S / 2          # 缩 S
        skip_optimizer_step()
        growth_count = 0
    else:
        optimizer.step()
        growth_count += 1
        if growth_count >= N:    # 通常 N=2000
            S = S * 2      # 扩 S
            growth_count = 0

PyTorch torch.cuda.amp.GradScaler 内置实现。

7.1.4 BF16:现代默认

Brain Floating Point 16-bit (BF16) 由 Google Brain 提出:

  • 指数位 8(同 FP32):动态范围一致,无需 loss scaling
  • 尾数位 7:精度低于 FP16(27 vs 210),但对深度学习已够

A100 / H100 / TPU 全面支持 BF16 张量核,性能与 FP16 相同。

优势

  • 训练稳定(无溢出/下溢)
  • 代码简单(无需 GradScaler)
  • 主权重可以保持 BF16 也可以 FP32(实验显示 BF16 主权重也 OK,省内存)

LLaMA、Mistral、Gemini、Qwen 等几乎所有现代大模型都用 BF16 训练。

注意:BF16 精度低,长时间训练可能累积误差。Google 在 Gemini 训练中遇到过 BF16 → FP32 master weights 的精度差异问题。

7.1.5 FP8 训练

Hopper 架构(H100/H800/B100)引入 FP8 张量核,FLOPs 是 BF16 的 2 倍。

两种 FP8 格式:

  • E4M3:精度优先(4 指数 + 3 尾数),动态范围 ±448,前向用
  • E5M2:动态范围优先(5 指数 + 2 尾数),范围 ±57344,反向用

Per-tensor Scaling

由于 FP8 范围小,每个张量需要一个 scale factor s 把数值规范到 [448,448]

xFP8=round(xFP32s)

GEMM 后:

yFP32=sxsw(xFP8wFP8)FP32 累加

scale 可以静态(每层一个)或动态(每 tensor 实时统计)。

Per-channel / Per-block Scaling

更细粒度的 scaling 进一步降低量化误差:

  • Per-channel:每行/列一个 scale
  • Per-block (DeepSeek-V3):1×128 或 128×128 块一个 scale

DeepSeek-V3 用 per-block FP8 GEMM:

  • 输入 X 按 1×128 行块缩放
  • 权重 W 按 128×128 块缩放
  • 输出累加用 BF16

这种细粒度让 FP8 训练 loss 与 BF16 几乎完全重合。

FP8 训练实测

DeepSeek-V3 是首个公开的 FP8 大规模预训练:

  • 14.8T token 训练
  • 关键模块(embedding、output、部分 attention)保持 BF16 兜底
  • MFU 从 BF16 的 ~45% 提升到 ~60%
  • 训练成本降低 1/3

NVIDIA Transformer Engine 库提供 FP8 训练支持:

python
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(
    fp8_format=recipe.Format.HYBRID,  # E4M3 fwd, E5M2 bwd
    margin=0,
    interval=1,
    amax_history_len=16,
)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = te_linear(input)

7.1.6 精度选择决策

场景推荐
A100 / 一般 GPUBF16 + FP32 master
H100 / H800 + 大模型FP8 (per-block) + BF16 master
老硬件 (V100)FP16 + Loss Scaling
推理量化INT8 / FP8 / INT4
微调BF16(FP8 微调易不稳)

7.2 梯度累积与梯度检查点

7.2.1 梯度累积

目标:用小 micro-batch 模拟大 effective batch。

effective batch=micro-batch×accum_steps×world_size
python
optimizer.zero_grad()
for k in range(accum_steps):
    micro_batch = next_batch()
    loss = forward(micro_batch) / accum_steps  # 平均
    loss.backward()                             # 梯度累加
optimizer.step()

为什么除以 accum_steps:保持梯度量级与不累积时一致(否则 effective lr 会随 accum_steps 变大)。

7.2.2 大 batch 的意义

实证(Kaplan 2020、McCandlish 2018):训练大模型需要 critical batch size Bcrit 之上的 batch 才能保持收敛效率。

LLM 的 critical batch 极大:

  • LLaMA-2 7B: ~4M token
  • LLaMA-2 70B: ~16M token
  • GPT-4: 几十 M token(推测)

DP-only 难以达到这个量级(每卡 micro-batch + DP rank 总数有限),梯度累积是关键。

7.2.3 梯度检查点

激活内存压缩,详见第 6 章 §6.8。要点:

  • 显存:O(L)
  • 计算:增加 ~33%
  • Selective recomputation:仅重算便宜的(softmax),增加 ~5%

7.2.4 微批与梯度累积的权衡

配置优点缺点
大 micro-batch + 少 accum计算密度高(GPU 用满)显存压力大
小 micro-batch + 多 accum显存友好accum 间无通信,但每 micro-batch 都要计算

最优配置通过实测确定。LLaMA-3 405B:micro_batch=1, accum=124, DP=128 → effective batch ≈ 16M。


7.3 优化器

7.3.1 AdamW

Loshchilov & Hutter (2017) "Decoupled Weight Decay Regularization"。

公式

mt=β1mt1+(1β1)gtvt=β2vt1+(1β2)gt2m^t=mt1β1t,v^t=vt1β2tθt=θt1η(m^tv^t+ϵ+λθt1)

关键差异

AdamW vs Adam:weight decay λθ 解耦于 Adam 主体,不参与一阶/二阶动量。原 Adam 把 L2 正则塞进梯度 g+=λθ,再走 Adam 流程,相当于 weight decay 也被 momentum 平滑——这会导致大权重的 decay 效果被弱化。

AdamW 直接在更新中减 ηλθ,效果更纯粹。

LLM 默认超参

模型β1β2ϵλ
GPT-30.90.951080.1
LLaMA-1/2/30.90.951050.1
Chinchilla0.90.951080.1
DeepSeek-V30.90.951080.1
一般 ML0.90.9991080.01

LLM 普遍用 β2=0.95(不是默认的 0.999),原因:

  • β2 越大,二阶矩估计越平滑,但适应慢
  • 大 batch + 长 horizon 训练,0.95 更敏感地跟踪当前梯度
  • 0.999 在 LLM 上偶发收敛慢

Bias / LayerNorm 参数不加 weight decay(标量参数对正则不应敏感):

python
no_decay_params = []
decay_params = []
for name, p in model.named_parameters():
    if "bias" in name or "norm" in name or p.ndim == 1:
        no_decay_params.append(p)
    else:
        decay_params.append(p)

optimizer = torch.optim.AdamW([
    {"params": decay_params, "weight_decay": 0.1},
    {"params": no_decay_params, "weight_decay": 0.0},
], lr=lr, betas=(0.9, 0.95))

显存占用

每参数:master FP32 (4) + m FP32 (4) + v FP32 (4) = 12 字节。加上模型 (2) + 梯度 (2) = 16 B/param。

70B 模型 → 1120 GB(不分片)。这就是 ZeRO 的动机。

7.3.2 Adafactor

Shazeer & Stern (2018) "Adafactor: Adaptive Learning Rates with Sublinear Memory Cost"。

创新:秩-1 近似

Adam 的 v(二阶矩)对每个参数都要存一份。Adafactor 对矩阵参数做秩-1 分解:

VRm×n 是某层的二阶矩矩阵。Adafactor 只维护两个向量:

RRm:Ri=1njVij(行均值)CRn:Cj=1miVij(列均值)V^ij=RiCj/(1mkRk)

显存:O(m+n) 而非 O(mn),对 4096×4096 矩阵从 16M 元素降到 8K,节省 1000+ 倍

不需要 ϵ 与 momentum

Adafactor 推荐:

  • 不存一阶动量(默认)
  • 用 RMS clipping 替代 ϵ

进一步节省内存。T5 / PaLM 用 Adafactor 训练。

缺点

  • 在小模型/小数据上表现略差于 Adam
  • 实现复杂
  • 当前 LLM 主流仍然用 AdamW + ZeRO

7.3.3 Lion

Chen et al. (2023) "Symbolic Discovery of Optimization Algorithms",由 Google 用进化搜索找到的优化器。

公式

ct=β1mt1+(1β1)gtθt=θt1η(sign(ct)+λθt1)mt=β2mt1+(1β2)gt

注意:

  • ct 用于更新,但 mt 用不同的 β2 更新(典型 β1=0.9,β2=0.99
  • sign() 让所有参数更新等模长
  • 不需要 v,所以只有 1 份动量(vs Adam 的 2 份)

超参调整

由于 sign 让步长不依赖 gradient 量级:

  • 学习率 η 比 AdamW 小 3-10 倍
  • weight decay 大 3-10 倍

优劣

维度AdamWLion
显存 (优化器)8 byte/param4 byte/param
速度1.0x~1.0x(计算量相近)
Qualitybaseline持平或略好
调参敏感性高 (lr/wd 必须搭配)

PaLM-2、Gemini 部分实验用 Lion。LLaMA 系列仍用 AdamW。

7.3.4 SOAP, Sophia, Muon

2024-2025 年涌现的二阶/混合方法。

SOAP (Shampoo + Adam, 2024)

Shampoo (Gupta 2018) 用 Kronecker 分解的近似 full-matrix preconditioner。SOAP 在 Shampoo 的特征基下跑 Adam:

θt=θt1ηQAdam(Qg)

Q 是 Shampoo 预条件器的特征向量矩阵,每隔 Ts 步更新一次。

实测 SOAP 比 AdamW 快约 1.4x(同 step 数达到同 loss)。Anthropic 等实验室在用。

Muon (Keller Jordan, 2024)

针对矩阵参数:用 Newton-Schulz 迭代正交化梯度后再更新。

Gortho=NewtonSchulz(G),Wt=Wt1ηGortho

正交化让每行/列范数一致,等模长更新。

NanoGPT speedrun (2024) 显示 Muon 比 AdamW 在中等模型上快 1.5-2x。但对超大模型(>10B)是否仍占优尚在验证。

Sophia (Liu 2023)

二阶方法,用 Hutchinson 估计 Hessian 对角线。Adam 是用 g2 估计二阶信息,Sophia 直接用 Hessian。

实测 Sophia 在 1B 级别模型上比 AdamW 快 ~2x。但工程复杂度高,主流 LLM 训练尚未广泛采用。


7.4 学习率调度

7.4.1 Warmup

训练初期用线性增长学习率:

ηt=ηmaxtw,tw

w 是 warmup 步数,通常取总步数的 1-2%(如总 1M 步 → warmup 8K-20K 步)。

为什么需要 warmup

  • Adam 初期二阶矩 v^ 估计不准(方差大),过大 lr 会让步长无界
  • 模型刚初始化时表征不稳,剧烈更新会破坏 layer 间的协同
  • BF16 训练初期 loss spike 风险高

线性 vs 二次 vs 平方根:实证差异很小,线性最常见。

7.4.2 Cosine 衰减

Warmup 后用余弦曲线衰减:

ηt=ηmin+12(ηmaxηmin)(1+cos(π(tw)Tw))

其中 T 是总步数。

优势

  • 平滑下降,无突然变化
  • 末期 lr 自然趋小,提高最终精度
  • 是 GPT-3、LLaMA 等的默认选择

ηmin 取多少?通常 ηmax10%。完全到 0 容易让模型"卡住"。

python
def cosine_lr(t, lr_max, lr_min, warmup, total):
    if t < warmup:
        return lr_max * t / warmup
    progress = (t - warmup) / (total - warmup)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))

7.4.3 WSD (Warmup-Stable-Decay)

也叫 Trapezoidal scheduler。MiniCPM、DeepSeek 部分实验使用。

   lr
    |
η_max┤    _________________
    |   /                  \
    |  /                    \
    | /                      \
    |/                        \___
    └──┴────────────────────┴──┴── t
   warmup    stable        decay

阶段:

  1. Warmup:线性增长到 ηmax(与 cosine 同)
  2. Stable:恒定 ηmax,占总训练 70-90%
  3. Decay:最后 10-20% 衰减(cosine 或 1-sqrt)

优势

  • 可中途 fork:在 stable 阶段任意 ckpt 都可以"接着训",不像 cosine 必须预设总步数
  • 多版本释放:同一基座训不同长度的下游模型(如 1T → 1.4T → 2T tokens 各自 decay)
  • Continual training 友好:上游训完,下游微调用同样的 lr 直接接

DeepSeek-V3 训练分多阶段:8.1T 主预训练(WSD 风格) + 1.4T 中英平衡 + 0.3T 长上下文,每阶段 stable + decay。

1-sqrt 衰减

ηt=ηmax(1(tt0)/(Tt0))

比 cosine 衰减更陡,末期 lr 更小,对"质量退火"更激进。

7.4.4 Inverse Sqrt

ηt=ηmaxw/max(t,w)

T5 用。优点:完全确定,无需预设总步数;缺点:末期衰减慢,最终 lr 偏大。

7.4.5 学习率峰值的 scaling

经验法则:

ηmaxηrefbatch ratio1/2

batch 翻倍 → lr 增加 2 倍。但实际更复杂,需要按模型尺寸调(μP 提供理论框架,见 §7.6.6)。

LLaMA-2 7B 用 lr=3×104,70B 用 1.5×104(更小);DeepSeek-V3 用 4.2×104


7.5 Scaling Laws

7.5.1 Kaplan (OpenAI 2020)

Kaplan et al. "Scaling Laws for Neural Language Models" 是 LLM scaling 的开山之作。

Loss 作为 N(参数)和 D(数据)的函数:

L(N,D)(NcN)αN+(DcD)αD

经验拟合:αN0.076,αD0.095

结论:固定计算 C=6ND,把 C 全投到 N(即 N 大,D 小)。

问题:这个结论后来被 Chinchilla 推翻——Kaplan 的实验对 lr 调度 / batch 选择不充分,导致小模型 under-trained 而被低估。

7.5.2 Chinchilla (DeepMind 2022)

Hoffmann et al. "Training Compute-Optimal Large Language Models" 重新拟合:

L(N,D)=E+ANα+BDβ

经验:α0.34,β0.28E1.69(不可约 loss)。

计算最优

固定 C=6ND,最小化 L

LNNC+LDDC=0

代入约束 D=C/(6N),求得:

NoptCa,DoptCba=βα+β0.45,b=αα+β0.55

所以 ND 大致等比扩展。

经验比例 (Chinchilla):

Dopt/Nopt20

7.5.3 Chinchilla 的颠覆性

DeepMind 用此 law 训了 Chinchilla 70B + 1.4T tokens(D/N = 20:1),击败 Gopher 280B + 300B tokens(D/N = 1.07)——70B 模型用 1/4 的参数和等量计算赢了 280B

发现 Gopher 严重 under-trained,应当训 4-5T tokens 才达到 Chinchilla optimal。

GPT-3 175B + 300B tokens(D/N = 1.7)也是 under-trained。

7.5.4 LLaMA-3 的颠覆 Chinchilla

Chinchilla optimal 是training-optimal,但部署/推理成本没考虑。

LLaMA-3 8B 用 15T tokens(D/N = 1875:1),远超 Chinchilla 的 160B optimal(20:1)。原因:

  • 推理成本主导:8B 模型部署到亿级用户,每天推理成本远超训练
  • 小模型多训数据仍能持续提升(loss vs logD 在 15T 仍是 log-linear)
  • 训练成本 vs 终身部署成本:训 100x 的数据划算

LLaMA-3 论文:

"We find that even at the 15T token scale, the model is still improving in a log-linear fashion."

这说明 Chinchilla scaling law 是"等高线"(同等计算下最优),但实际选择要考虑部署成本——通常是深度 over-train 小模型

7.5.5 DeepSeek MoE Scaling Law

DeepSeek-AI (2024) 针对 MoE 的 scaling law:

用激活参数而非总参数

MoE 模型总参数 671B 但激活只有 37B。Loss 拟合应该用 Nact

L(Nact,D)=E+ANactα+BDβ

IsoFLOP 实验

固定 C,扫描 (Nact,D) 找最小 loss。结果:

  • MoE 的最优 D/Nact 比稠密大(因为有更多专家分摊)
  • 数据质量越高,可以训更多 token 而不饱和

数据质量加权

引入"有效数据" Deff=qDq 是质量系数(教育数据 q>1,通用 web q1)。

DeepSeek-V3 的训练 batch 大小、lr 等都是基于这套 scaling law 推导。

7.5.6 Scaling Law 的局限

Scaling law 是回归方程,实际模型还受:

  • 数据 mix 比例
  • 优化器选择
  • Tokenizer 压缩率
  • 架构细节(MoE / GQA / SwiGLU)

影响。盲目套用公式可能误导,需要小规模 ablation 验证。


7.6 训练稳定性

7.6.1 梯度裁剪

防止偶发的梯度爆炸毁掉训练:

ggmin(1,cg)

全局范数 clip c=1.0 是 LLM 标配。注意是所有参数的梯度作为一个大向量算 L2 范数。

python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

监控指标:clip 触发率应 < 1%。频繁 clip 说明 lr 过大或数据有问题。

7.6.2 z-loss

PaLM (Chowdhery 2022) 引入:

Lz=104(logZ)2,Z=iezi

z 是输出层 logits(softmax 前)。logZ 是 partition function。

作用:限制 logits 的整体偏移,防止训练中 logits 数值漂移导致 softmax 数值不稳。

加入总 loss:L=LLM+104Lz

GPT-4、Gemini、DeepSeek 等大型模型都用。

7.6.3 Embedding Scaling

某些模型在输入 embedding 上加 d 缩放:

xt=dEmb(t)+PE(t)

GPT-J、LLaMA 共享 embedding 与 unembedding,把 d 因子吸收到初始化中。

7.6.4 Skip Connection Scaling (DeepNet)

Wang et al. (2022) "DeepNet" 训了 1000 层 Transformer:

xl+1=αxl+Sublayer(LN(xl))

α<1 防止深层残差累积爆炸。DeepNet 推荐 α=(2L)1/4

LLaMA 等 < 100 层模型不需要这个。

7.6.5 初始化

标准 Pre-LN Transformer:

σinit=25d

GPT-2 / GPT-3 风格。Output projection 与 FFN 的 W2 进一步缩 2L 倍:

σW2=12Lσinit

控制残差路径上的方差累积。

7.6.6 μP (Maximal Update Parametrization)

Yang & Hu (2022) "Tensor Programs V":

目标:让最优超参(学习率、init scale、attention scale)在不同模型宽度 d保持一致

意义:在小模型(d=128)上做 hyperparameter sweep,最优配置可以直接迁移到大模型(d=12288)。

核心规则

层类型LR scaleInit scale
Embedding (输入)11/d
隐层 (matrix)1/d1/d
输出层1/d1/d

注:scale 是相对于"Standard Parametrization (SP)"的差异。

Attention scale

μP 把注意力分数除以 dh(不是 dh):

S=QKdh

但等价于把 WQ 的 init scale 多除 dh,所以实际公式不变。

实证

  • Cerebras 报告用 μP 把 hyperparameter 从 40M 模型迁移到 13B,几乎没有 quality 损失
  • Anthropic 在 Claude 训练前用 μP 在小模型扫超参
  • Mistral / Qwen 部分实验用 μP

局限

μP 推导基于 infinite-width limit,实际有限宽度有偏差。一般用作"超参 anchor",再小调。

7.6.7 Loss Spike 处理

BF16 训练偶发 loss 突然飙升 100x,原因可能是:

  • 数据中长重复段(如 "ababab..." × 10000)
  • 数据中低质量片段(乱码、机器拼接)
  • 优化器状态偶发数值异常

处理

  1. 跳过 batch:检测 loss > k × 移动平均 → 跳过此 micro-batch,optimizer 不更新
  2. 回滚 ckpt:spike 后 loss 不恢复 → 回到上一 ckpt,跳过引发 spike 的 batch 重训
  3. 数据清洗:保存 spike 的 batch,事后分析数据质量问题

DeepSeek-V3 报告训练全程零回滚——通过细致的数据清洗 + bias-based router + FP8 per-block scaling 等多重保险。

7.6.8 监控指标

指标健康范围异常含义
Train loss平滑下降spike → 数据/优化器问题
Gradient norm0.1 - 10持续 > 100 → 不稳定
Loss scale (FP16)稳定频繁缩 → 梯度溢出
MFU35-60%突降 → 通信/IO 问题
Token throughput稳定抖动 → 节点故障
Validation loss下降偏离 train 太多 → 过拟合
Z-loss< 0.01偏大 → logits 数值飘

监控工具:W&B、TensorBoard、Aim、自建 dashboard。LLaMA-3 训练用了 Llama Stack 内建监控 + 自定义 alert。


7.7 完整训练 Loop 示例

下面是一个综合上述所有要点的训练 loop 骨架:

python
import torch
import torch.distributed as dist
from torch.cuda.amp import autocast
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# 1. 模型 + 分布式
model = build_model().cuda()
model = FSDP(model, mixed_precision=bf16_policy, sharding_strategy=FULL_SHARD)

# 2. 优化器(区分 decay / no-decay 参数)
decay_params, no_decay_params = split_params(model)
optimizer = torch.optim.AdamW([
    {"params": decay_params, "weight_decay": 0.1},
    {"params": no_decay_params, "weight_decay": 0.0},
], lr=lr_max, betas=(0.9, 0.95), eps=1e-8)

# 3. lr scheduler
def get_lr(step):
    if step < warmup_steps:
        return lr_max * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))

# 4. 训练 loop
for step in range(total_steps):
    # 设 lr
    lr = get_lr(step)
    for pg in optimizer.param_groups:
        pg["lr"] = lr

    # 梯度累积
    optimizer.zero_grad()
    accum_loss = 0
    for k in range(accum_steps):
        batch = next_batch()
        with autocast(dtype=torch.bfloat16):
            logits = model(batch.tokens)
            loss_lm = cross_entropy(logits, batch.labels)
            loss_z = z_loss_coef * (logits.logsumexp(-1) ** 2).mean()
            loss = (loss_lm + loss_z) / accum_steps
        loss.backward()
        accum_loss += loss.item()

    # 梯度裁剪
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # 检查 spike
    if accum_loss > 5.0 * moving_avg_loss:
        log(f"SPIKE detected at step {step}, skip update")
        optimizer.zero_grad()
        continue

    optimizer.step()

    # 监控
    if step % 10 == 0 and is_rank_zero():
        log({
            "step": step,
            "loss": accum_loss,
            "grad_norm": grad_norm.item(),
            "lr": lr,
            "tokens_per_sec": ...,
        })

    # checkpoint
    if step % ckpt_interval == 0:
        save_checkpoint(model, optimizer, step)

7.8 实战 Tips

  1. 先在小尺度验证:把所有逻辑(数据、模型、优化器、并行)在 124M-1B 规模跑通,再扩到大模型
  2. 超参从已知配方开始:LLaMA-2 / LLaMA-3 / DeepSeek 都公开了详细配方,照着做再小调
  3. 每个变量只改一个:scaling 实验中,一次只换一个变量(lr / batch / model size),便于归因
  4. Validation 设置好:保留 0.1-1% 的高质量数据作 val,每 1000 step 评一次
  5. Loss curve 与 throughput 是仪表盘:盯着 W&B 看,发现异常立即停训诊断
  6. checkpoint 频率不要省:1-4 小时一次完整 ckpt,恢复成本远小于丢失训练
  7. 多备份代码与配置:训练脚本、tokenizer model、分词后的数据 binary、随机种子,全部纳入版本控制
  8. 预算给 debug:千卡训练前预留 5-10% 的总预算用于故障/调参

7.9 本章小结

  1. 混合精度:FP16 + Loss Scaling → BF16(现代默认)→ FP8 per-block(H100 时代)。DeepSeek-V3 是首个公开 FP8 训练。
  2. 梯度累积 让小卡训大 effective batch;梯度检查点 用 33% 计算换 O(L) 显存。
  3. AdamW 是 LLM 默认优化器,β2=0.95,weight decay 0.1,bias/norm 不 decay;显存 12 B/param。
  4. 学习率:linear warmup + cosine decay 是经典;WSD 在 multi-stage / continual training 中更灵活。
  5. Scaling Laws:Chinchilla 给出 D20N 的 training optimal;LLaMA-3 颠覆为 D20N 的 inference-optimal。
  6. 稳定性:grad clip @ 1.0,z-loss 104,μP 跨尺度迁移超参,spike skip / rollback 兜底。
  7. 工程实践:先小尺度验证 → 大尺度跑;监控 dashboard 必备;checkpoint 是生命线。

至此预训练核心技术全部讲完。下一篇起,我们进入 SFT 与 RL 阶段。


7.10 思考题

  1. FP8 的 per-block scaling 必要性:DeepSeek-V3 用 1×128 行块缩放输入、128×128 块缩放权重。请定量分析:若改为 per-tensor scaling(每个 weight 矩阵一个 scale),FP8 GEMM 的最大相对误差会扩大多少倍?为什么 per-block 对长序列训练特别重要?

  2. AdamW 显存 vs Lion:训练 70B 模型,FP16 + ZeRO-3,64 卡。请分别计算 AdamW 与 Lion 的优化器状态显存(per GPU)。Lion 节省多少 GB?这能否让单卡装得下原本需要 offload 的模型?

  3. Scaling Law 推论:DeepSeek-V3 用 14.8T token 训 671B 总参(37B 激活)的 MoE。按 Chinchilla 用激活参数计算,"compute-optimal" 应该用多少 token?为什么 DeepSeek 选择 over-train(vs LLaMA-3 类似策略)?

  4. WSD vs Cosine 的 fork 实验:你在 stable 阶段(lr = 4×104 恒定)训了 5T tokens,想 fork 出两个版本:A 继续训到 8T、B 直接 decay 到 4×105 训 0.5T。两个版本的最终 loss 大致差距是多少?请用 scaling law 估算。

  5. Loss Spike 实战:你在 step 80000 / 1000000 检测到 loss 从 2.1 跳到 8.5。你需要:(a) 立即采取什么动作?(b) 如何定位是数据问题、还是优化器问题、还是数值问题?(c) 写一个简洁的诊断 checklist。

基于 MIT 协议发布