5. MoE 架构
导读:MoE (Mixture of Experts) 通过稀疏激活,让模型总参数量翻倍而 FLOPs 不变。它是 GPT-4、Mixtral、DeepSeek-V3 等顶级模型的共同选择。本章从基本公式出发,讲解 Switch Transformer、Mixtral、DeepSeek-MoE 的架构演进,剖析负载均衡、容量因子、Expert Choice 等核心机制,并讨论专家并行的通信开销。
5.1 为什么需要 MoE
5.1.1 稠密模型的 Scaling 瓶颈
按 Chinchilla scaling law,模型
- 训练成本:
FLOPs, 翻倍则 FLOPs 翻倍 - 推理成本:每生成 1 token 都要走完所有
个参数,延迟随 线性增长
是否存在"参数量大但 FLOPs 不变"的架构?这就是 MoE 的目标。
5.1.2 MoE 的核心思想
人脑约 860 亿神经元,但任意一个时刻只有 1-3% 神经元在放电。同理,对每个 token,没必要让所有参数都参与——只激活与该 token 相关的子模块。
把 FFN 替换为
参数量:
5.1.3 历史脉络
| 时间 | 模型 | 关键创新 |
|---|---|---|
| 1991 | Jacobs et al. | 提出 MoE 概念(adaptive mixture of local experts) |
| 2017 | Shazeer et al. "Outrageously Large NN" | LSTM-based MoE,1.4 万亿参数 |
| 2020 | GShard | Top-2 路由,扩展到 600B 参数 |
| 2021 | Switch Transformer | Top-1 简化路由,1.6T 参数 |
| 2022 | GLaM | 1.2T 参数,性能接近 GPT-3 |
| 2023 | Mixtral 8×7B | 首个公开高质量 MoE,47B/13B |
| 2024 | DeepSeek-MoE | 细粒度专家 + 共享专家 |
| 2024-12 | DeepSeek-V3 | 671B/37B,无辅助 loss 负载均衡,MTP 训练 |
5.2 基本 MoE 数学
5.2.1 路由器与门控
输入 token 向量
其中
5.2.2 Top-K 选择
把 logits 经过 softmax 得到概率分布,再取最大的
或者先取 Top-K 再 softmax(更常见):
两种做法在
5.2.3 输出聚合
5.2.4 计算量
每 token:
- Router:
个 expert FFN:
相比稠密 FFN,计算量近似
但显存占用是
5.3 Switch Transformer
Fedus et al. (2021) "Switch Transformers" 提出Top-1 路由的极简 MoE:每 token 只选 1 个专家。
5.3.1 设计动机
Top-2 之前是默认(GShard),但:
- Top-1 实现更简单(无需 softmax 重归一化)
- 通信只需 1 次 all-to-all
- 实测 Top-1 + 更多专家 ≈ Top-2 + 较少专家(在固定计算预算下)
5.3.2 公式
注意:保留 softmax 概率作为权重(不是简单的 1.0),让梯度能流过 router。
5.3.3 容量因子 (Capacity Factor)
理想情况下每个专家收到
引入容量
其中
超容量的 token 怎么办?两种策略:
- Drop(Switch 默认):超过
的 token 直接走残差(残差连接保留输入),相当于这个 token 跳过 MoE 层 - No-Drop:让多余 token 排队,下个 micro-batch 再处理(实现复杂,少见)
Drop 比例称为 token drop rate,监控指标,理想 < 1%。
5.3.4 辅助负载均衡 Loss
如果不约束,router 容易"塌缩"——所有 token 都路由到同一个专家。需要引入辅助 loss鼓励均匀分布。
设
辅助 loss:
通常
为什么这个形式?
- 均匀分布时
, ,乘 后 (min 值) - 极端塌缩时
,其他都是 0, ,乘 后 (max 值) - 梯度:
( 不可微,作为常数),推动 向高 的反方向调整——实际上是惩罚"高频专家进一步被路由"
工程实现:
def switch_aux_loss(probs, indices, num_experts, alpha=0.01):
# probs: [T, N] softmax 输出; indices: [T] argmax
f = torch.zeros(num_experts, device=probs.device)
f.scatter_add_(0, indices, torch.ones_like(indices, dtype=torch.float))
f = f / probs.shape[0] # 实际路由比例
P = probs.mean(0) # 软概率均值
return alpha * num_experts * (f * P).sum()5.3.5 Router z-loss
为防止 router logits 数值过大导致 softmax 不稳:
权重
总 loss:
5.3.6 Switch 配置
Switch-T:
- 编码器层:每隔一层把 FFN 替换为 MoE 层(不是每层都换)
不等 - 总参数 1.6T(最大版)
- 训练吞吐比 T5-XXL 快 7x
5.4 GShard 与 Top-2 路由
GShard (Lepikhin et al. 2020) 是 Switch 之前的方案,Top-2 路由。
每 token 选两个专家:
其中
Mixtral 8×7B 用 Top-2,每 token 2 个专家激活。
5.4.1 Top-2 vs Top-1 对比
| 维度 | Top-1 (Switch) | Top-2 (GShard, Mixtral) |
|---|---|---|
| 计算量 | 1×FFN | 2×FFN |
| 通信 | 1× all-to-all | 2× all-to-all |
| 实现复杂度 | 简单 | 中等 |
| 鲁棒性 | drop 1 个专家就丢信息 | 备份机制 |
| Quality | 较低 | 较高 |
实践:
大( ):Top-1 效率高 小( ):Top-2 质量更好
5.4.2 Mixtral 8×7B 配置
- 8 个专家,每个相当于一个 Mistral 7B 的 FFN
- Top-2 路由
, (Mistral 风格 SwiGLU) - 32 层 decoder
- 每层都是 MoE(不像 Switch 隔层)
- 总参数 47B, 激活参数 13B
- 仅在 FFN 用 MoE,attention 仍是稠密 GQA
实测 Mixtral 8×7B 在多数 benchmark 上接近 LLaMA-2 70B(13B 激活 vs 70B 激活),是 MoE 路线的关键证据。
5.5 DeepSeek-MoE
Dai et al. (2024) "DeepSeekMoE" 针对传统 MoE 的两个问题:
- 知识混杂:每个专家内部学了过多重叠知识
- 专家冗余:相同基础能力(如语法、常识)在多个专家重复存储
5.5.1 创新 1:细粒度专家分割
思路:保持总参数不变,把专家做小、变多。
设原 MoE:
个专家 个激活
总参数
好处:
- 专家组合数大大增加:原
种组合 → 现 种 - 每个专家更"专精",可以捕获更细致的语义
- 计算量近似不变(FLOPs ∝
,相同)
DeepSeek-V3 实际配置:256 个路由专家,Top-8 激活。
5.5.2 创新 2:共享专家隔离
某些通用知识(语法规则、常用词义)应该被所有 token 用到,让普通专家学这些是浪费。
引入共享专家 (shared expert):
个共享专家,强制激活(所有 token 都过) 个路由专家,从中选 个
DeepSeek-V3 用
效果:
- 共享专家承担"基础设施",路由专家专注"特化能力"
- 路由更平衡(不需要每个专家都掌握通用知识)
- 实测同 FLOPs 下能力提升 2-5%
5.5.3 创新 3:无辅助 Loss 负载均衡 (DeepSeek-V3)
问题:辅助 loss
DeepSeek-V3 创新:给每个专家加可调偏置
注意:
- TopK 用
选择 - 但
取的是未加偏置的 值
亲和度评分
偏置更新
每个 micro-batch 后:
- 若专家
过载( ): - 若专家
欠载:
类似自动控制:偏置充当"反馈控制器",把负载推回均衡。
序列级辅助 loss(轻量)
虽然全局靠偏置,但单个序列内仍可能极端不均(如某序列全是数学,把数学专家用爆)。补充一个极小的 sequence-wise loss:
效果:DeepSeek-V3 在 14.8T token 上几乎没有性能损失就实现了负载均衡,是 MoE 工程的重要进展。
5.5.4 路由策略对比
| 模型 | 共享 | 负载均衡 | ||
|---|---|---|---|---|
| Switch-T | 32-2048 | 1 | 0 | aux loss + z-loss |
| GShard | 2048 | 2 | 0 | aux loss |
| Mixtral 8×7B | 8 | 2 | 0 | aux loss |
| DBRX | 16 | 4 | 0 | aux loss |
| DeepSeek-V2 | 160 + 2 | 6 | 2 | aux loss + device-level |
| DeepSeek-V3 | 256 + 1 | 8 | 1 | bias-based, 几乎无 aux |
| Qwen2-MoE | 64 + 4 | 4 | 4 | aux loss |
5.6 训练稳定性
5.6.1 路由崩溃 (Route Collapse)
少数专家被频繁选中,其他专家几乎不被激活。
原因:
- 训练初期 router 随机,某些专家偶然胜出
- 胜出的专家得到更多梯度,更善于处理任意 token,进一步胜出("富者愈富")
解决:
- 辅助 loss / 偏置(DeepSeek)
- Expert Choice(见下)
- Token-choice 切换为 expert-choice
5.6.2 Token Drop 损失
容量超限的 token 被 drop,相当于这部分 token "白训了"(仍用残差,但 MoE 层贡献为零)。
监控:每个 batch 的 drop rate 应 < 1%;> 5% 说明容量因子过小。
解决:
- 增大
(默认 1.0 → 1.25) - Pad-and-Drop:drop 后再补充随机 token 维持 batch shape
- 训练时用 dropless batching:动态调整 batch 内 token 数
5.6.3 数值不稳
症状:训练中途 loss 突然爆炸、router logits → ∞
原因:
- Router 是线性层,没有归一化,logits 可能任意大
- softmax 的 exp 在大值时溢出
解决:
- Router z-loss
- Router 输出 FP32 计算(即使其他用 BF16)
- Logits clip:
5.6.4 专家初始化
每个专家独立初始化的话,training 初期可能某专家恰好"看起来好"。
技巧:
- 共享专家从相同种子初始化
- 路由专家用相同 RMSNorm scale,避免幅度差异
5.7 Expert Choice:换个思路
Zhou et al. (2022) 提出 Expert Choice (EC) 路由:
反向选择:不是 token 选 expert,而是 expert 选 token——每个专家从所有 token 中选 Top-K 个最适合自己的。
5.7.1 公式
设
每个专家选自己排名前
5.7.2 性质
| 性质 | Token-Choice | Expert-Choice |
|---|---|---|
| 负载均衡 | 需要 aux loss | 天然均衡 |
| 每 token 专家数 | 固定 | 不固定(可能 0、可能多) |
| 因果性 | 保留 | 破坏(需要看完整 batch 才能选) |
| 训练适用 | 是 | 是 |
| 推理适用 | 是 | 否(自回归无法预知后续 token) |
EC 的"天然均衡"非常吸引人,但因果性问题让它只能用于训练(实际工程中也少见,因为推理切回 token-choice 会引入分布偏移)。
5.7.3 GShard EC 变体
部分实现用 EC 训练 + token-choice 推理,效果中等。DeepSeek 路线选择了 token-choice + 偏置均衡,不走 EC。
5.8 专家并行 (Expert Parallelism)
5.8.1 基本思路
把
5.8.2 All-to-All 通信
每层 MoE 的前向流程:
- 本地 GPU 算完 attention,得到 token 表示
- All-to-All 1:把每个 token 发到它的目标专家所在的 GPU
- 各 GPU 上的专家计算(本地 GEMM)
- All-to-All 2:把结果送回原 token 所在的 GPU
- 加权求和(如果 Top-K, K>1)
每次 all-to-all 的通信量:
5.8.3 DeepSeek-V3 的 EP
DeepSeek-V3 用 64-way 专家并行(256 专家分到 64 GPU),跨节点 all-to-all 是性能瓶颈。
优化:DualPipe + 通信内核
DualPipe 是 DeepSeek 自创的流水线并行,把 PP 与 EP 的通信完全 overlap:
- PP 做反向计算时,EP 做下一 micro-batch 的 all-to-all
- 用 NVLink + IB 双链路,单链路打满
约束:每 token 跨 4 个节点的限制(避免通信爆炸)。这通过节点级路由实现:先选节点(看每节点上专家的最大亲和度),再选节点内专家。
5.8.4 通信开销估算
设
每层 MoE 通信量(FP16):
DeepSeek-V3:
61 层(含 MoE 层)→ 每个 step 单方向通信
H800 IB 带宽 50 GB/s(双向 400 Gb/s),
5.9 MLA (Multi-head Latent Attention)
DeepSeek 的另一个创新(虽然不是 MoE 本身),与 MoE 一起构成 DeepSeek-V2/V3 架构。
5.9.1 动机
MoE 大幅压缩了 FFN 参数(37B 激活/671B 总),但 attention 部分仍然稠密。在长上下文推理时,KV-Cache 成为新瓶颈。
GQA-8 已经把 KV 缩小 8 倍,能否更激进?
5.9.2 MLA 设计
把 K、V 投影到一个低维 latent 向量
只缓存
推理时再"解压":
由于注意力是
直接在 latent 空间做注意力。
5.9.3 KV-Cache 对比
DeepSeek-V2 配置:
加上 RoPE 兼容设计(
5.9.4 MLA + MoE = DeepSeek-V3
V3 配置:
- 61 层(其中 58 层 MoE,3 层稠密)
- 注意力:MLA,
, , - MoE:256 路由专家 + 1 共享,Top-8
- 671B 总参,37B 激活
- 14.8T token 训练,FP8 mixed precision
- 上下文 128K(YaRN 扩展)
效果:在多数 benchmark 上比肩 GPT-4o,开源、可商用,是 2024 年最重要的开源 LLM。
5.10 工程实现要点
5.10.1 PyTorch 版 MoE 层
class MoELayer(nn.Module):
def __init__(self, dim, hidden_dim, num_experts, top_k):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.gate = nn.Linear(dim, num_experts, bias=False)
self.experts = nn.ModuleList([
FFN(dim, hidden_dim) for _ in range(num_experts)
])
def forward(self, x):
# x: [B, T, d]
B, T, D = x.shape
x_flat = x.view(-1, D) # [B*T, d]
N = x_flat.size(0)
# Router
logits = self.gate(x_flat) # [N, E]
topk_val, topk_idx = logits.topk(self.top_k, dim=-1) # [N, K]
topk_val = F.softmax(topk_val, dim=-1)
# 朴素实现:遍历专家
out = torch.zeros_like(x_flat)
for e in range(self.num_experts):
mask = (topk_idx == e).any(dim=-1)
if not mask.any():
continue
tokens = x_flat[mask]
expert_out = self.experts[e](tokens)
# 加权
weights = topk_val[mask] * (topk_idx[mask] == e).float()
weight = weights.sum(-1, keepdim=True)
out[mask] += expert_out * weight
return out.view(B, T, D)朴素版有 Python 循环开销大、负载不均的问题。生产用 Megatron-Core 或 DeepSpeed-MoE,融合内核 + 高效 all-to-all。
5.10.2 Switch-style aux loss
def switch_aux_loss(logits, top1_idx, num_experts, alpha=0.01):
probs = F.softmax(logits, dim=-1) # [N, E]
f = torch.zeros(num_experts, device=logits.device)
f.scatter_add_(0, top1_idx, torch.ones_like(top1_idx, dtype=torch.float))
f = f / probs.size(0) # 实际比例
P = probs.mean(0) # 软概率均值
return alpha * num_experts * (f * P).sum()5.10.3 Router z-loss
def router_z_loss(logits, beta=1e-3):
logsumexp = torch.logsumexp(logits, dim=-1)
return beta * (logsumexp ** 2).mean()5.10.4 DeepSeek-V3 风格的 bias-based router
class DeepSeekRouter(nn.Module):
def __init__(self, dim, num_experts, top_k, gamma=1e-3):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.gamma = gamma
self.gate = nn.Linear(dim, num_experts, bias=False)
# bias 不是 nn.Parameter,是手动维护的 buffer
self.register_buffer("bias", torch.zeros(num_experts))
def forward(self, x):
logits = self.gate(x) # [N, E]
scores = torch.sigmoid(logits) # 用 sigmoid 而非 softmax
# Top-K 选择基于 scores + bias
adjusted = scores + self.bias
topk_val, topk_idx = adjusted.topk(self.top_k, dim=-1)
# gate weight 用未加 bias 的 scores
weights = scores.gather(-1, topk_idx)
weights = weights / weights.sum(-1, keepdim=True)
return topk_idx, weights, scores
@torch.no_grad()
def update_bias(self, topk_idx):
# 每 step 后调用
N = topk_idx.numel() / self.top_k
f = torch.zeros(self.num_experts, device=topk_idx.device)
f.scatter_add_(0, topk_idx.flatten(), torch.ones_like(topk_idx.flatten(), dtype=torch.float))
f = f / (N * self.top_k / self.num_experts) # 归一化到 1.0 = 平均
# 过载(f > 1)→ bias 减;欠载 → bias 加
self.bias -= self.gamma * (f - 1.0).sign()5.11 主流 MoE 模型一览
| 模型 | 总参 | 激活 | 负载均衡 | 备注 | ||
|---|---|---|---|---|---|---|
| Switch-T | 1.6T | ~7B | 32-2048 | 1 | aux + z | 编码器,T5 风格 |
| GShard | 600B | - | 2048 | 2 | aux | 翻译模型 |
| GLaM | 1.2T | 96B | 64 | 2 | aux | 稠密性能 |
| Mixtral 8×7B | 47B | 13B | 8 | 2 | aux | 完全开源 |
| Mixtral 8×22B | 141B | 39B | 8 | 2 | aux | 22B 基座 |
| DBRX | 132B | 36B | 16 | 4 | aux | Databricks |
| Arctic | 480B | 17B | 128 | 2 | aux | Snowflake |
| DeepSeek-V2 | 236B | 21B | 160 + 2 | 6 | aux + device | MLA + DeepSeekMoE |
| DeepSeek-V3 | 671B | 37B | 256 + 1 | 8 | bias-based | MLA + FP8 + MTP |
| Qwen2-MoE | 57B | 14B | 64 + 4 | 4 | aux | 阿里开源 |
| MiniMax-M1 | 456B | 45.9B | 32 | 2 | aux | 混合 attention |
5.12 MoE 的局限
5.12.1 显存仍受总参数约束
虽然 FLOPs 仅
部署上需要:
- 跨 GPU 切分(专家并行)
- 量化(FP8 / INT4)
- offload(CPU / NVMe)
5.12.2 推理动态性
每 batch 的负载分布不可预测,可能某 GPU 闲、某 GPU 忙。这给推理调度带来挑战:
- vLLM、TensorRT-LLM 都在加 MoE 优化
- DeepSpeed-MoE 推理引擎专门优化 all-to-all
- 实际部署 MoE 推理 throughput 通常低于同尺寸 dense 模型
5.12.3 训练超大 batch 的依赖
MoE 需要大 batch 才能让所有专家都被训到。Mixtral / DeepSeek-V3 都用 token batch ≥ 16M(4096 seq × 4096 micro_batch × DP)。这要求大集群。
5.12.4 微调难度
SFT、RLHF 阶段 batch 通常较小,MoE 容易出现专家激活不均,效果可能比 dense base 差。常见做法:
- 微调时冻结 router
- 上采样(同 prompt 重复几次)
- 用 expert-choice 训练以保证均衡
5.13 本章小结
- MoE 通过稀疏激活解耦参数量与 FLOPs,是当前 SOTA 模型的标配。Mixtral 47B 用 13B 激活、DeepSeek-V3 671B 用 37B 激活。
- 路由机制从 Top-1 (Switch) 演进到 Top-K (GShard, Mixtral) 再到细粒度 + 共享 (DeepSeek)。
- 负载均衡经历了 aux loss → bias-based 的演进,DeepSeek-V3 的 bias-based 几乎无 quality 损失。
- 训练稳定性靠 z-loss、FP32 路由、专家初始化。
- 专家并行 (EP) 必须配合 all-to-all 通信,DualPipe 等技术做计算-通信 overlap。
- MLA + MoE 是 DeepSeek 路线的精髓:MLA 压 KV-Cache,MoE 压 FFN,整体高效。
下一章讨论分布式训练的全景。
5.14 思考题
细粒度专家的临界点:DeepSeek-V3 用
的细粒度配置。请分析当 (同时 保持 不变)时会出现什么问题?给出一个权衡分析(如 router 的 计算量、per-expert capacity 退化等)。 辅助 loss 与主 loss 冲突:传统 MoE 的
与 共同训练,会让 router 偏向"均匀分布"而非"准确分布"。请定量分析当 、 时,辅助 loss 在总 loss 中的占比,并解释为什么 DeepSeek 的 bias-based 方法能避免此冲突。 MLA vs GQA 的工程权衡:MLA 把 KV-Cache 减小 ~50 倍,但引入了
两个解压矩阵的参数量与计算开销。设 ,请计算 MLA 在训练时 attention 部分的总参数量与 FLOPs,与同尺寸 GQA-8 对比。 专家并行的通信瓶颈:64 卡 EP,每 step 需要 460 GB 单向 all-to-all 通信,IB 带宽 50 GB/s 单链路。若不做 overlap,单 step 通信耗时 9.2 s。请提出至少 3 种 overlap 策略(参考 DualPipe),并估算理想情况下能掩盖多大比例的通信。