6. 分布式训练
导读:训练 70B+ 的模型需要数千张 GPU,单卡显存与计算都远远不够。本章讲解从 DDP 到 ZeRO、张量并行、流水线并行、3D 并行的完整体系,包含 Megatron-LM、DeepSpeed、FSDP 的核心机制与通信开销分析。
6.1 单卡的极限:为什么必须分布式
6.1.1 显存开销分解
训练一个参数量
| 组件 | 大小 (字节) | 说明 |
|---|---|---|
| 模型参数 | FP16 | |
| 梯度 | FP16 | |
| AdamW master 参数 | FP32 | |
| AdamW | FP32 | |
| AdamW | FP32 | |
| 总计 | ||
| 激活 | 反向需要 |
LLaMA-2 7B:
激活更夸张:seq_len = 4096,batch = 8,约 50-80 GB(不算 checkpoint)。
6.1.2 计算开销
LLaMA-2 70B 训 2T tokens:
H100 BF16 峰值 989 TFLOPS,按 50% MFU 算:
必须用数千卡并行,把训练时间降到几周到几个月。
6.1.3 并行的四个维度
按"分什么"区分:
| 维度 | 分什么 | 通信类型 | 通信量 |
|---|---|---|---|
| DP (数据并行) | batch 切分 | 梯度 AllReduce | |
| TP (张量并行) | 矩阵切分 | 激活 AllReduce | 每层多次,与 batch × seq 成正比 |
| PP (流水线并行) | 层切分 | 激活 P2P | 相邻 stage,与 batch × seq 成正比 |
| EP (专家并行) | MoE 专家切分 | All-to-All | 与 K × token 成正比 |
实际生产用 3D(或 4D)并行:
6.2 数据并行的演进
6.2.1 朴素 DP
每张 GPU 复制完整模型,处理不同 batch 切片,反向后用 AllReduce 同步梯度。
# 概念伪代码
model = Model().cuda()
broadcast(model, src=0) # 所有 rank 同步初始权重
for batch in dataloader:
local_batch = scatter(batch)
loss = model(local_batch)
loss.backward()
for p in model.parameters():
p.grad = all_reduce(p.grad) / world_size # 平均
optimizer.step()显存:每卡
6.2.2 DDP (Distributed Data Parallel)
PyTorch 标准实现,相比朴素 DP 的关键优化:
- 桶化 (bucketing):把多个梯度合并成一个 bucket(默认 25MB),AllReduce 一次
- 重叠计算与通信:反向传播时,某层梯度算完后立即触发 AllReduce,与后续层的反向计算 overlap
- 梯度累加无通信:
no_sync()context 内多次反向不通信,最后一次 sync
import torch.nn.parallel as parallel
model = parallel.DistributedDataParallel(
model.cuda(),
device_ids=[local_rank],
bucket_cap_mb=25,
find_unused_parameters=False,
)通信量:每 step
6.2.3 FSDP (Fully Sharded Data Parallel)
PyTorch 1.11+ 内置的 ZeRO-3 等价实现。每张 GPU 只持
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload, MixedPrecision, BackwardPrefetch,
)
model = FSDP(
model,
cpu_offload=CPUOffload(offload_params=False),
mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3 等价
)显存:
6.3 ZeRO:DeepSpeed 的核心
6.3.1 思想
Rajbhandari et al. (2020) "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models" 观察:
DDP 中每张卡持有的"模型状态"(参数 + 梯度 + 优化器状态)是冗余的——每张卡都有完整副本。如果在
6.3.2 三级分片
| Stage | 分片对象 | 单卡显存 | 通信量 |
|---|---|---|---|
| DDP | 无 | ||
| ZeRO-1 ( | 优化器状态 | ||
| ZeRO-2 ( | + 梯度 | ||
| ZeRO-3 ( | + 参数 |
Stage 1: 分片优化器状态
每张卡只更新自己负责的
流程:
- 反向:所有 GPU 计算完整梯度,AllReduce
- 更新:每张 GPU 只更新自己分片的 master +
- AllGather:把更新后的参数分片广播给所有卡
显存:参数
通信:与 DDP 相同(
Stage 2: 分片梯度
每卡只保存自己负责分片的梯度。
流程:
- 反向:每层算完梯度立即 ReduceScatter(每卡只收到自己负责分片的梯度),其他梯度可以释放
- 更新:每张 GPU 用自己分片的梯度更新自己分片的优化器状态
- AllGather:把更新后的参数分片广播
通信:ReduceScatter (
显存:参数
Stage 3: 分片参数
每张 GPU 平时只持自己分片的参数
流程:
- 前向:每层之前 AllGather 当前层参数,计算完释放
- 反向:每层之前 AllGather 当前层参数,反向算完释放;梯度 ReduceScatter
- 更新:每张 GPU 更新自己分片的参数
通信:每层多 1 次前向 AllGather(反向也 1 次)→ 总
显存:
6.3.3 内存公式总结
设激活显存为
当
6.3.4 ZeRO-Offload / ZeRO-Infinity
进一步把状态卸载到 CPU 或 NVMe:
| 方案 | 卸载到 | 增加单卡能跑的最大模型 |
|---|---|---|
| Vanilla ZeRO-3 | GPU only | |
| ZeRO-Offload | CPU 内存 | 更大(CPU 通常 1TB+) |
| ZeRO-Infinity | NVMe SSD | 几乎无限(带宽差) |
8x A100 80GB + 1TB CPU 内存:ZeRO-Offload 可微调 175B 模型。NVMe 还能跑 1T+。
代价:CPU-GPU PCIe 带宽(~25 GB/s)远低于 HBM(~1.5 TB/s),训练速度下降 2-5 倍。
6.3.5 ZeRO 的代码
# DeepSpeed
import deepspeed
ds_config = {
"train_batch_size": 1024,
"fp16": {"enabled": True},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {"device": "cpu"},
"offload_param": {"device": "cpu"},
"overlap_comm": True,
"reduce_scatter": True,
"contiguous_gradients": True,
}
}
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)PyTorch FSDP 等价于 ZeRO-3,开箱即用:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)6.4 张量并行 (Tensor Parallelism)
6.4.1 动机
DP/ZeRO 在层之间切分,但单层很大时(如 405B 模型的
6.4.2 MLP 的张量并行
考虑 MLP:
Column Parallel(按列分 )
把
每张 GPU 计算:
GELU 是逐元素的,无需通信。每张 GPU 持有
Row Parallel(按行分 )
把
每张 GPU 计算:
最终
完整流程
GPU i:
Y_i = GELU(X A_i) # 无通信
Z_i = Y_i B_i # 本地 GEMM
Z = AllReduce(Z_i) # 跨 GPU 求和每个 MLP 块:前向 1 次 AllReduce,反向 1 次 AllReduce。
通信量:
6.4.3 自注意力的张量并行
MHA 由
按列切:每张 GPU 持
输出投影
每个 attention 块:前向 + 反向 各 1 次 AllReduce。
6.4.4 整层的通信
Transformer 块前向:
attention block: 1 AllReduce
FFN block: 1 AllReduce
→ 总 2 AllReduce
反向: 2 AllReduce
→ 单层总 4 AllReduce通信量:
6.4.5 与 算子
Megatron-LM 论文用
模型中插入
6.4.6 序列并行 (Sequence Parallelism, SP)
TP 之外,LayerNorm 和 dropout 的输入仍是完整的
SP(Korthikanti et al. 2022)把这些"非 TP 部分"按序列维切分,每张 GPU 处理
通信变化:
- 原 TP:AllReduce
- TP + SP:ReduceScatter (出 TP 块) + AllGather (入下一 TP 块)
ReduceScatter + AllGather = AllReduce 通信量,但激活显存减少
6.4.7 TP 配置选择
- TP 通信发生在每层多次,对带宽要求极高 → 必须用 NVLink (节点内)
通常 = 节点内 GPU 数(A100/H100 节点 8 卡 → ) 跨节点会因 IB 带宽限制(vs NVLink)显著降速
LLaMA-3 405B 训练:TP=8,PP=16,DP=128,总 16384 卡(H100)。
6.5 流水线并行 (Pipeline Parallelism)
6.5.1 思想
把模型按层切到
GPU 0: layer 0-9 (stage 0)
GPU 1: layer 10-19 (stage 1)
GPU 2: layer 20-29 (stage 2)
GPU 3: layer 30-39 (stage 3)前向:每个 stage 算完自己的层,把激活送给下一 stage。反向:反方向传梯度。
通信只发生在相邻 stage 间(P2P send/recv),跨节点带宽要求低。
6.5.2 Naive Pipeline (GPipe)
Huang et al. (2019) 提出的最早方案:
把全局 batch 切成
时间 →
GPU 0: F0 F1 F2 F3 B3 B2 B1 B0
GPU 1: F0 F1 F2 F3 B3 B2 B1 B0
GPU 2: F0 F1 F2 F3 B3 B2 B1 B0
GPU 3: F0 F1 F2 F3 B3 B2 B1 B0
(空闲)Bubble (空闲泡):流水启动和结束时部分 GPU 空闲。
设单 micro-batch 前向时间
为让 bubble ≤ 10%,需要
激活显存:所有
6.5.3 1F1B (One-Forward-One-Backward)
PipeDream (Harlap 2018) 与 Megatron 的核心调度。
思路:流水启动后,前向和反向交替,让激活及时释放。
GPU 0: F0 F1 F2 F3 B0 F4 B1 F5 B2 ...
(B0 释放 F0 激活)
GPU 1: F0 F1 F2 F3 B0 F4 B1 F5 ...
GPU 2: F0 F1 F2 F3 B0 F4 B1 ...
GPU 3: F0 F1 F2 F3 B0 F4 ...激活显存:
Bubble 比例:与 GPipe 相同
6.5.4 Interleaved 1F1B
Narayanan et al. (2021) "Megatron-LM" 提出。每 GPU 持有
v=2 时:
GPU 0: layers [0-3, 16-19] (chunk 0a, 1a)
GPU 1: layers [4-7, 20-23] (chunk 0b, 1b)
GPU 2: layers [8-11, 24-27]
GPU 3: layers [12-15, 28-31]前向:先依次走完所有 chunk a,再走所有 chunk b。
Bubble 比例:
代价:
- 通信量增加
倍 - 要求
是 的倍数
6.5.5 Zero Bubble Pipeline
Qi et al. (2023) "Zero Bubble Pipeline Parallelism" 进一步优化:
关键观察:反向传播包含两部分:
- B:算 input gradient(向前传播)
- W:算 weight gradient(更新本层权重)
W 是局部计算,不阻塞流水。把 B 和 W 分开调度,可以填满 bubble。
ZB-1p 调度(与 1F1B 相同显存):bubble ≈ 1/3 of 1F1B ZB-2p 调度(2x 激活显存):bubble ≈ 0
DeepSeek-V3 训练用类似的 DualPipe 算法,通信和计算几乎完全 overlap。
6.5.6 PP 切分策略
层数
但实际不能均分:
- 第一个 stage 多一个 embedding 层
- 最后一个 stage 多一个 LM head + loss 计算
- 不同层计算量略有差异
Megatron 用 transformer_layers_per_pp_stage 列表手动配置:
# 32 层模型,PP=4,第一个 stage 少 1 层(因为有 embedding)
transformer_layers_per_pp_stage: [7, 8, 9, 8]6.5.7 PP 的通信
每个 stage 之间的边界:前向发激活 (
通信量与
可以放在节点间(IB 带宽够,且不在 critical path)。
6.6 3D 并行
6.6.1 组合规则
总 GPU 数
| 维度 | 通信带宽要求 | 部署位置 |
|---|---|---|
| TP | 极高(每层多次 AllReduce) | 节点内 NVLink |
| PP | 中(相邻 stage P2P) | 节点间 IB 也可 |
| DP | 中(每 step 1 次 AllReduce,可 overlap) | 节点间 IB |
经验法则:TP 节点内 + PP 跨少量节点 + DP 跨剩余节点。
6.6.2 LLaMA-3 405B 训练配置
16384 张 H100:
- TP = 8(节点内)
- PP = 16
- DP = 128
总
每 PP stage 持
6.6.3 DeepSeek-V3 配置
2048 张 H800(NVLink 600 GB/s, IB 50 GB/s 单链路):
- PP = 16
- EP = 64(专家并行)
- TP = 1(因为 MLA 显存压得很小)
- ZeRO-1 DP = 2(剩余)
DeepSeek-V3 论文:选 TP=1 主要因为 H800 互联弱(vs H100 的 NVLink 900 GB/s),TP 通信代价高;MLA 大幅降低单层显存让 TP=1 可行。
6.6.4 通信成本估算
设 hidden
单 step 总通信量(粗略):
- TP AllReduce:
(每层 4 次,FP16) - PP P2P:
(前向 + 反向,每个 stage 边界) - DP AllReduce:
B / step
实际:
- TP 是大头(占 50-70%)
- DP 与计算 overlap(不在 critical path)
- PP 与流水阶段 overlap
6.7 通信原语详解
6.7.1 主流原语
| 原语 | 操作 | 通信量 (per GPU, |
|---|---|---|
| Broadcast | rank 0 → all | |
| Reduce | all → rank 0 | |
| AllReduce | reduce + broadcast | |
| AllGather | 拼接所有分片 | |
| ReduceScatter | reduce 后切分 | |
| All-to-All | 每 rank 发 | |
| P2P (Send/Recv) | 点对点 |
6.7.2 Ring AllReduce
经典算法,最优
设
阶段 1: ReduceScatter(
第
阶段 2: AllGather(
第
每步通信
其中
6.7.3 Tree AllReduce
NCCL 的另一种算法,对**小数据 + 大
NCCL 自动根据数据量选择算法。
6.7.4 NCCL & 硬件
| 互联 | A100 | H100 | H800 |
|---|---|---|---|
| 节点内 NVLink | 600 GB/s | 900 GB/s | 400 GB/s |
| 节点内 NVSwitch | 全互联 | 全互联 | 全互联 |
| 节点间 IB (单卡) | HDR 200 Gb/s | NDR 400 Gb/s | NDR 400 Gb/s |
H800 是 H100 的"出口阉割版":算力相同(FP8 1979 TFLOPS),但 NVLink 砍半,跨节点 IB 不变。这就是 DeepSeek 选择 TP=1(避免 NVLink 瓶颈)的底层原因。
6.7.5 通信重叠
DDP 反向时,每 bucket 梯度算完立即触发 AllReduce,与后续层反向 overlap:
时间 →
计算: B7 B6 B5 B4 B3 B2 B1 B0
通信: AR7 AR6 AR5 AR4 AR3 AR2 AR1 AR0理想情况下通信完全隐藏,实际能 overlap 70-90%。
ZeRO-3 / FSDP 的前向 AllGather 可与上一层的计算 overlap,反向 ReduceScatter 同理。
6.8 激活内存优化
6.8.1 激活的来源
每个 Transformer 层在前向产生激活,反向时需要:
- 输入
(残差用) - attention 的
- LayerNorm 的
- FFN 的中间激活
每 token 单层激活约
LLaMA-2 7B (32 层,
超过单卡显存的 1/3。
6.8.2 梯度检查点 (Gradient Checkpointing)
Chen et al. (2016) "Training Deep Nets with Sublinear Memory Cost":
只保存"检查点"(如每个 Transformer block 的输入),反向时局部重算。
显存:
代价:forward 多算 1 次,FLOPs 增加 ~33%(因为 backward 本身 flops 是 forward 的 2 倍,重算一次 forward = 33% 增加)。
from torch.utils.checkpoint import checkpoint
def forward(self, x):
for layer in self.layers:
x = checkpoint(layer, x, use_reentrant=False)
return x6.8.3 Selective Activation Recomputation
只重算"便宜"的(如 attention softmax),保留"昂贵"的(如 GEMM 输出)。
Megatron-LM 实现:默认对 attention 部分(QK^T、softmax、PV)做 selective recomputation,FLOPs 增加仅 ~5%。
6.8.4 序列并行 + 激活分片
TP + SP 把激活按序列维分到
LLaMA-3 405B 训练用 TP=8,激活显存压缩 8 倍。
6.8.5 CPU Offload
ZeRO-Infinity 把激活 offload 到 CPU。需要:
- PCIe 带宽足够(A100 PCIe 4.0: 32 GB/s)
- 异步传输 overlap
实测训练速度下降 1.5-2 倍,但能跑下原本不可能的模型。
6.9 大规模训练的工程实践
6.9.1 启动框架
| 框架 | 特点 |
|---|---|
| Megatron-LM | NVIDIA 出品,TP/PP/DP 全栈,3D 并行最成熟 |
| DeepSpeed | 微软,ZeRO 系列,offload 强 |
| FSDP | PyTorch 原生 ZeRO-3 |
| Megatron-DeepSpeed | 两者结合,BLOOM 用过 |
| NeMo | NVIDIA 上层封装 |
| OpenLLM 系列 | LLaMA-Factory, Axolotl,社区中等规模 |
| Colossal-AI | HPC-AI Tech,社区活跃 |
6.9.2 LLaMA-3 训练配置示例
# 简化的 megatron 配置
model:
num_layers: 126
hidden_size: 16384
num_heads: 128
num_kv_heads: 8 # GQA-8
ffn_hidden_size: 53248 # SwiGLU
vocab_size: 128256
distributed:
tensor_parallel_size: 8
pipeline_parallel_size: 16
data_parallel_size: 128
precision:
fp8_format: e4m3 # H100 only
bf16: true # master + 激活
optimizer:
type: AdamW
lr: 8e-5
weight_decay: 0.1
beta1: 0.9
beta2: 0.95
scheduler:
type: cosine
warmup_steps: 8000
total_steps: 1200000
data:
seq_length: 8192
global_batch_size: 16M tokens
micro_batch_size: 1
gradient_accumulation: 1246.9.3 Checkpoint 与容错
千卡训练中节点故障几乎是必然的(每千卡每天 ~1-3 次硬故障)。
Checkpoint 设计:
- 每 1-4 小时存一次完整 ckpt
- 异步 ckpt:后台线程把 GPU → CPU → 持久化存储
- 增量 ckpt:只存变化部分(实验性)
- 健康检查:定期 NCCL all-reduce 测连通性
LLaMA-3 论文报告 16K H100 集群每天 ~9 次故障,平均每次恢复 35 分钟,全程有效训练时间占比 ~90%。
6.9.4 监控指标
| 指标 | 健康范围 |
|---|---|
| MFU (Model FLOPs Utilization) | A100: 35-50%, H100: 50-70% |
| Token 吞吐 | 数千-数万 token/s/GPU |
| Loss | 平滑下降,无尖峰 |
| Gradient norm | 0.1-10,clip 在 1.0 |
| Loss spike rate | < 1 / 1000 step |
| Comm overlap | > 80% |
6.10 通信开销分析(实例)
6.10.1 LLaMA-2 70B 单 step 估算
- 70B 参数 → ZeRO-1 DP AllReduce:
B GB - 64-way DP @ 200 Gb/s IB Ring AllReduce:
s
单 step 计算时间约 5-10 s(取决于 batch 大小),DP 通信占比 5-10%,可与反向 overlap。
6.10.2 LLaMA-3 405B 单 step
- 反向时间约 10 s
- DP-128 ZeRO-3 通信约 3 s(可 overlap 70%)
- TP-8 AllReduce: 每层 4 次 × 126 层 ×
≈ 400 GB(不可 overlap) - TP NVLink 900 GB/s → 0.5 s
- PP P2P:相邻 stage 间,与反向重叠
实际 MFU 约 50-55%。
6.10.3 优化方向
- TP 通信压缩:FP8 通信、低秩近似
- DP-TP 联合:把 DP AllReduce 与 TP AllReduce 共线优化
- 拓扑感知:让通信链路与物理拓扑对齐
- DualPipe(DeepSeek):PP 与 EP 全 overlap
6.11 本章小结
- DP → DDP → ZeRO → FSDP 逐步分片模型状态,从
单卡到 。 - ZeRO 三级:仅优化器状态、+ 梯度、+ 参数,通信量从
增到 ,显存从 减到 。 - 张量并行 (TP):MLP 列并行 + 行并行 + AllReduce,attention 按头切分;通信高频,必须 NVLink 节点内。
- 流水线并行 (PP):1F1B 调度激活显存
,bubble 比例 ;Interleaved 与 Zero Bubble 进一步缩小 bubble。 - 3D 并行:
,根据互联拓扑决定每维大小。LLaMA-3 405B 用 TP=8 PP=16 DP=128。 - 通信原语:Ring AllReduce
最优带宽利用,Tree AllReduce 延迟优;NCCL 自动选择。 - 激活优化:梯度检查点 + 序列并行 + selective recomputation,缺一不可。
下一章我们讨论训练工程的最后一公里——精度、优化器、scaling law、稳定性。
6.12 思考题
ZeRO 分级显存推导:参数量
, ,FP16 + AdamW。请分别计算 DDP / ZeRO-1 / ZeRO-2 / ZeRO-3 单卡静态显存占用,并说明在 H100 80GB 上各自能否放下激活(假设激活 30GB)。 3D 并行最优配置搜索:在 4096 张 H100(节点 8 卡 NVLink,节点间 NDR 400 Gb/s IB)上训 175B 模型。请设计 TP × PP × DP 配置,使得 (a) MFU 最大化,(b) 单 step 通信时间不超过计算时间的 20%。给出至少 2 种合理配置并对比。
Bubble Ratio 优化:1F1B 流水线
,micro-batch 。请计算朴素 bubble、Interleaved 1F1B ( ) bubble、Zero Bubble (ZB-2p) bubble。在显存允许的前提下,哪种调度的有效 throughput 最高? TP 通信瓶颈分析:H800 NVLink 400 GB/s,BF16 训练 LLaMA-405B(
,TP=8)。每层 TP AllReduce 通信量是多少?8 卡 Ring AllReduce 时间是多少?若用 RDMA over IB(50 GB/s)会慢多少倍?这解释了为什么 TP 必须放在节点内。