11. SFT 实战与工程
前三章讲清了 SFT 的「是什么、为什么、用什么数据、怎样降显存」,本章把这些组件拼装成一个完整的工程项目。我们讨论决策树、多轮对话、长上下文扩展、分布式训练、常见踩坑、超参数实战值,以及主流工具链对比。
11.1 全参 vs PEFT:决策树
不是所有任务都该用 LoRA,也不是所有任务都该全参 SFT。下面是一个决策树:
┌──────────────────────────┐
│ 你要做什么? │
└────────────┬─────────────┘
│
┌──────────────────────────┼──────────────────────────────┐
▼ ▼ ▼
大幅改变模型行为 行业知识注入 垂直任务(分类/抽取/QA)
(base → chat, (医疗/法律/金融) 样本通常 < 50K
多语言 alignment) 需要新事实
│ │ │
▼ ▼ ▼
全参 SFT ┌─ 数据 > 100K LoRA r=8-16
(大量数据, 改风格) │ 显存充足
│ │ Yes
│ ▼
│ 全参 SFT (LR ≤ 1e-5)
│
└─ 数据 ≤ 100K 或显存紧
│
▼
LoRA r=64+ 或 DoRA r=32
混合通用数据防遗忘11.1.1 何时选全参 SFT
- 从 base model 训练 chat model(如 Llama-3-base → Llama-3-Instruct):行为分布大幅改变。
- 大型 SFT 数据(> 200K,多领域混合):参数容量需求高。
- 多模态 alignment(文本 + 图像 / 音频):模态对齐需要深入修改表征。
- 算力充足(多张 H100 / A100,训练时间不是瓶颈)。
11.1.2 何时选 LoRA / QLoRA
- 垂直任务(医疗 NER、法律分类、客服 QA):本质上调整输出格式 + 少量知识注入。
- 多任务并存:一个 base + 多个 LoRA adapter,部署成本极低。
- 数据量小(< 5K):LoRA 的低参数量自带正则化,过拟合风险更低。
- 显存受限:单卡 24GB 训 7B、48GB 训 70B,QLoRA 是唯一选择。
- 快速实验:训练耗时短,迭代快。
经验法则:LoRA 适合「改风格 + 加技能」,但若要「改世界观」(注入大量新事实),仍需全参或较大 r。
11.2 多轮对话 SFT
11.2.1 拼接策略
最常见做法:把多轮对话用 chat template 拼成一整条 sequence,对每个 assistant 段落都计算 loss:
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
你好<|im_end|>
<|im_start|>assistant
你好!有什么可以帮你?<|im_end|> ← 计算 loss
<|im_start|>user
讲个笑话<|im_end|>
<|im_start|>assistant
为什么程序员喜欢黑暗?因为光会带来 bug。<|im_end|> ← 计算 loss- 优点:单 sample 训练所有 assistant 轮,效率高,模型学到「在对话上下文里如何响应」。
- 注意:所有 user/system/tool 部分必须 mask(label = -100)。
11.2.2 多轮 vs 单轮分裂
另一种策略是把多轮对话「分裂」成多个单轮样本:
样本 1: [system, user₁, assistant₁]
样本 2: [system, user₁, assistant₁, user₂, assistant₂]
样本 3: [system, user₁, assistant₁, user₂, assistant₂, user₃, assistant₃]每个样本只对最后一轮 assistant 计算 loss。
| 维度 | 整轮拼接 + 全 mask | 分裂为多个单轮 |
|---|---|---|
| 训练效率 | 高(一次 forward) | 低(重复前缀) |
| 长度浪费 | 无 | 大(前缀重复) |
| 早期 turn 学习信号 | 充分 | 充分(多次重复前缀) |
| 实现复杂度 | 简单 | 简单 |
| 推荐 | 首选 | 仅在 base model 不善于多轮时用 |
11.2.3 TRL SFTTrainer 多轮代码
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
# 每个样本已是 messages 列表格式
config = SFTConfig(
assistant_only_loss=True, # 自动对所有 assistant 段算 loss
max_seq_length=4096,
packing=False, # 多轮对话不建议 packing(除非用 4D mask)
)
trainer = SFTTrainer(model=model, args=config, train_dataset=ds, tokenizer=tokenizer)assistant_only_loss=True 会:
- 调用 chat template 时识别
{% generation %}标记。 - 对所有 assistant 段生成 mask。
- 把非 assistant 位置的 label 置为 -100。
11.2.4 处理 system prompt
- 训练时:system prompt 的内容不计算 loss(被 mask)。
- 推理时:system prompt 由用户/应用控制。
- 多样化 system prompt:训练时让 system 多样(不同人设、不同任务说明),模型才能在推理时理解新 system。
11.2.5 Sequence Packing
为减少 padding 浪费,把多个短对话打包成一条长 sequence:
原始(max_len=4096,平均每条 800 tokens):
[conv1: 800] [pad: 3296] ← 80% 浪费
[conv2: 600] [pad: 3496]
...
Packing:
[conv1: 800][conv2: 600][conv3: 1200][conv4: 900][conv5: 596] ← 0 浪费关键:packing 时必须用 attention mask 隔离样本,否则 conv1 的 token 会 attend 到 conv2,污染训练信号。
两种实现:
- 4D attention mask:构造
(B, H, T, T)mask,跨样本位置置 0。显存开销大。 - FlashAttention
cu_seqlens:传入累积序列长度数组(如[0, 800, 1400, 2600, 3500, 4096]),FlashAttn 内部按段计算。推荐。
Axolotl、TRL 都支持 packing:
config = SFTConfig(
packing=True,
max_seq_length=8192,
eval_packing=False, # eval 时关掉 packing 便于看每条样本指标
)实战收益:在 UltraChat 上 packing + flash_attn 比无 packing 训练快 2-4 倍(具体取决于长度分布)。
11.3 长上下文 SFT
预训练 4K 的模型要支持 32K/128K 上下文,需要位置编码扩展(针对 RoPE 模型)。
11.3.1 RoPE 复习
旋转位置编码(RoPE)把每个 query/key 向量按位置
其中
问题:模型只在训练长度
11.3.2 Position Interpolation(PI, Chen et al., 2023, Meta)
把 position 索引等比缩放到训练范围内:
其中
- 优点:实现简单,与训练分布最一致。
- 缺点:高频维度被「压缩」,丢失细粒度位置区分能力。
- 需要少量长上下文数据微调(< 1000 步)。最大可扩展约 8×。
11.3.3 NTK-aware(bloc97, 2023)
不直接缩放 position,而是改变 RoPE 的基底
直觉:高频维度(
NTK-aware 可以做无需微调的 zero-shot 扩展(2-4×),效果不错。
11.3.4 YaRN(Peng et al., 2023, EleutherAI / Nous)
「Yet another RoPE extensioN」综合了 NTK-by-parts 和 attention temperature scaling:
Step 1: NTK-by-parts
根据 RoPE 维度的「波长」
- 高频(
远小于上下文)→ 不插值(外推)。 - 低频(
远大于上下文)→ 全插值(PI)。 - 中间(
接近上下文)→ 平滑过渡(ramp 函数,由 控制)。
Step 2: Attention temperature scaling
因 context 变长,attention logits 分布变化(softmax 的「峰值锐度」下降),需缩放:
LLaMA 实测:
效果
YaRN 把 LLaMA-2 扩展到 128K 仅需 400-600 步微调(比 PI 减少 10×)。Mistral 7B 128K、Qwen2、Qwen2.5 都用 YaRN 作为长上下文方案。
11.3.5 LongLoRA(Chen et al., ICLR 2024)
观察:long-context SFT 显存瓶颈是 attention(
Shifted Sparse Attention(S²-Attn)
训练时把序列分成多块,每块内做 full attention,半数 head 做 shift(位移半块),模拟跨块信息交换:
Head 1-N/2: ┌──┐ ┌──┐ ┌──┐ ┌──┐ (块内 attention)
[b1] [b2] [b3] [b4]
Head N/2+1-N: shift by half block, then 块内 attention
←─[b1+b2/2]→ ←─[b2/2+b3]→ ...这种近似在训练时显存
LoRA + 全部 norm/embedding 训练
LongLoRA 配合 LoRA 训练 attention,额外让 norm 和 embedding 全部可训练(这两类参数对长上下文特别关键)。
LLaMA-2 7B 扩到 100K 仅用 8×A100。
11.4 分布式 SFT
单卡或单机已不够大模型 SFT。两大主流分布式方案:FSDP 和 DeepSpeed ZeRO。
11.4.1 FSDP(PyTorch Fully Sharded Data Parallel)
PyTorch 原生方案。把模型参数、梯度、优化器状态切片到所有 GPU:
| 模式 | 切片内容 | 等价 ZeRO |
|---|---|---|
FULL_SHARD | 参数 + 梯度 + 优化器状态 | ZeRO-3 |
SHARD_GRAD_OP | 梯度 + 优化器状态 | ZeRO-2 |
NO_SHARD | 不切(仅 DDP) | DDP |
HYBRID_SHARD | 节点内 FULL_SHARD + 节点间复制 | - |
LoRA + FSDP:base 权重切片 + LoRA adapter 各 GPU 完整副本(参数小可承受)。70B QLoRA + FSDP 可在 4×H200 上跑。
11.4.2 DeepSpeed ZeRO
微软方案,三个阶段:
- ZeRO-1:切优化器状态。
- ZeRO-2:再切梯度。
- ZeRO-3:再切参数(等价 FSDP
FULL_SHARD)。
ZeRO-3 + offload(CPU/NVMe)可把 70B 全参 SFT 跑在普通服务器(虽然慢)。但通信开销较 FSDP 大,LoRA 场景下 FSDP 通常更优。
11.4.3 配置示例
Accelerate FSDP config
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
num_processes: 8
fsdp_config:
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_offload_params: false
fsdp_use_orig_params: true # LoRA 必须 true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_forward_prefetch: false
fsdp_cpu_ram_efficient_loading: true
fsdp_sync_module_states: true启动:
accelerate launch --config_file fsdp.yaml train.pyDeepSpeed ZeRO-3 config
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {"device": "cpu", "pin_memory": true},
"offload_param": {"device": "cpu", "pin_memory": true},
"overlap_comm": true,
"contiguous_gradients": true,
"stage3_gather_16bit_weights_on_model_save": true,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9
},
"bf16": {"enabled": true},
"gradient_accumulation_steps": 4,
"gradient_clipping": 1.0,
"train_micro_batch_size_per_gpu": "auto"
}启动:
deepspeed --num_gpus=8 train.py --deepspeed ds_config.json11.4.4 QLoRA + FSDP 组合
QLoRA + FSDP 是 2024 年最流行的「单机 70B SFT」方案:
- base 量化为 4-bit(仅 ~35 GB / 70B)。
- FSDP 把 4-bit base 在 4-8 卡间切片。
- LoRA adapter 各卡完整副本(小,可忍)。
- 4×H100 80GB 即可微调 70B。
注意:FSDP + bitsandbytes 早期有兼容性问题(2023 年),现已稳定。HuggingFace Accelerate 提供了开箱即用的脚本。
11.5 常见问题与对策
11.5.1 灾难性遗忘(Catastrophic Forgetting)
模型 SFT 后丢失通用能力(如数学、推理变差)。
症状:
- MMLU、GSM8K、HumanEval 等通用基准下降。
- 模型只在训练数据风格上表现好,问别的就「装糊涂」。
- 多语言能力丢失。
原因:
- LR 过高,破坏预训练表示。
- 数据偏窄(如全是医疗对话)。
- 训练步数过多。
缓解策略:
| 策略 | 适用场景 | 实现 |
|---|---|---|
| 降低 LR | 始终适用 | 全参 2e-5 → 5e-6;LoRA 1-2e-4 |
| 混入通用数据 | 垂直领域 SFT | domain : general = 1 : 1 ~ 1 : 3 |
| 使用 LoRA | 显存允许时优选 | 改变小,遗忘少 |
| 早停 | 数据少时 | 监控 dev set 通用能力 + 任务表现 |
| EWC / Replay | 学术研究 | LLM 中很少使用 |
ICML 2025 有论文(Improved SFT for Mitigating Catastrophic Forgetting)提出:用基础模型自身重建通用指令分布 + 多模型生成过滤合成通用数据,与新数据混合训练。
11.5.2 过拟合(小数据场景)
- 1K-10K 样本,1-3 epoch 即可(更多通常过拟合)。
- 监控 train/eval loss,若 eval 反弹立即停。
- LR warmup 5%-10% steps 线性升温。
- LR scheduler 用 cosine decay 到 10% peak。
11.5.3 格式不合规
模型生成的输出不遵循 chat template(如不输出 <|im_end|>、漏掉 JSON 闭合括号)。
原因 1:训练数据本身格式不严谨。
对策:清洗时做格式校验(解析 JSON、检查 token 完整性)。
原因 2:训练时未对 EOS / 终止 token 计算 loss(mask 太激进)。
对策:确保 <|im_end|> 等终止 token 是 assistant 部分,被纳入 loss。
原因 3:generation 配置未对齐。
对策:推理时设置 eos_token_id 为模型实际使用的终止 token。
11.5.4 重复生成
模型生成时陷入循环(重复同一句、同一段)。
对策:
- 推理设置
repetition_penalty=1.05-1.15、no_repeat_ngram_size=3。 - 检查 SFT 数据中是否有大量重复样本(合成数据常见问题)。
- 适度提高
temperature(0.7-0.9) 和top_p(0.9-0.95)。
11.5.5 Loss 为 NaN
- 检查是否混用 FP16 / BF16 不当。BF16 更稳定,应优先。
- 检查
gradient_checkpointing_kwargs={"use_reentrant": False}(reentrant 模式有梯度问题)。 - 检查样本中是否有空 assistant 内容(mask 后整个序列全是 -100,loss 退化)。
- 学习率过高(尤其 LoRA r 大时)。
11.6 超参数实战指南
| 超参 | 全参 SFT | LoRA | QLoRA | 备注 |
|---|---|---|---|---|
| Learning Rate | 1e-5 ~ 5e-5 | 1e-4 ~ 3e-4 | 1e-4 ~ 2e-4 | LoRA 比全参高 1 个量级 |
| Batch Size (effective) | 64-256 | 16-64 | 16-32 | 通过 grad accum 实现 |
| Epochs | 1-3 | 1-5 | 1-3 | 多于 3 一般过拟合 |
| Warmup | 3-10% | 3-10% | 3-10% | linear |
| Scheduler | cosine | cosine | cosine | decay 到 peak 的 10% |
| Weight Decay | 0.0-0.1 | 0.0 | 0.0 | LoRA 不需要 |
| Max Seq Len | 4K-32K | 4K-32K | 2K-8K | 看显存 |
| Optimizer | AdamW | AdamW | paged_adamw_8bit | QLoRA 用 paged |
| Grad Clip | 1.0 | 1.0 | 0.3-1.0 | 防梯度爆炸 |
| Precision | bf16 | bf16 | bf16 (compute) | FP16 不稳定 |
| Grad Checkpointing | yes | yes | yes | 省 ~40% 显存 |
| Flash Attention | v2/v3 | v2/v3 | v2/v3 | 几乎是标配 |
经验:
- LoRA 的 LR 比全参高一个量级:因可训参数少,需要更大步长才有可见进展。
- BF16 比 FP16 更稳定:动态范围相同于 FP32(只是精度低),无 loss scaling 问题。
- Gradient checkpointing:节省显存约 40%,速度损失约 20-30%——绝大多数项目应开启。
- Flash Attention 2/3:几乎是标配,提速 + 省显存。
11.7 主流工具生态
11.7.1 HuggingFace TRL(SFTTrainer)
- 定位:研究级标准库,原生 transformers 集成。
- 优点:
- 文档全,社区活跃。
- 支持
assistant_only_loss、completion_only_loss、packing、PEFT 无缝集成。 - 同时支持 SFT / DPO / GRPO / RLOO / KTO / ORPO 等训练。
- 与 Accelerate / DeepSpeed / FSDP 无缝。
- 适合:研究、自定义训练逻辑、需要从 SFT 走到 RLHF 的端到端项目。
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(model=model, args=SFTConfig(...), train_dataset=ds)
trainer.train()11.7.2 Axolotl
- 定位:YAML 配置驱动,工业级 + 全功能。
- 优点:
- 模板丰富(每个流行模型都有官方 example YAML)。
- 完整 RLHF 流程(SFT/DPO/PPO/GRPO/RM)。
- 多模态(LLaMA-Vision、Qwen2-VL、Pixtral)。
- Sample packing、FSDP+QLoRA 一流支持。
- 适合:从 SFT 到 DPO 的端到端项目;不想写 Python 训练脚本。
YAML 示例(节选):
base_model: Qwen/Qwen3-7B
model_type: AutoModelForCausalLM
load_in_4bit: true
datasets:
- path: HuggingFaceH4/ultrachat_200k
type: chat_template
chat_template: chatml
field_messages: messages
adapter: qlora
lora_r: 64
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
micro_batch_size: 4
gradient_accumulation_steps: 4
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-4
warmup_ratio: 0.05
bf16: true
gradient_checkpointing: true
flash_attention: true
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: false
fsdp_use_orig_params: true启动:
accelerate launch -m axolotl.cli.train config.yaml11.7.3 LLaMA-Factory
- 定位:Web UI + 100+ 模型支持。
- 优点:
- 零代码,GUI 入门友好。
- Unsloth 加速后端集成(约 2.1× 速度)。
- 中文社区主导,中文文档完善。
- ACL 2024 论文。
- 适合:新手入门、零代码用户、中文用户。
11.7.4 Unsloth
- 定位:单 GPU 极致性能。
- 优点:
- Triton + 手写 CUDA kernel,2-5× 提速、70-80% 显存削减。
- 单 GPU 优化最佳。
- 与 HuggingFace 生态兼容。
- 缺点:多 GPU 支持有限(持续改进中)。
- 适合:消费级 GPU(4090 / 3090)、预算有限场景。
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
"unsloth/Qwen3-7B-bnb-4bit",
max_seq_length=4096,
load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
model, r=16, target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=16, use_gradient_checkpointing="unsloth",
)
# ... 后续与 TRL SFTTrainer 接续11.7.5 速度与显存对比
LLaMA-3.1 8B QLoRA r=16, A100 80GB, max_seq_len=4096:
| 工具 | tokens/s | 相对速度 | VRAM |
|---|---|---|---|
| Unsloth (4-bit) | ~4200 | 2.8× | ~8 GB |
| Axolotl QLoRA | ~1500 | 1.0× | ~16 GB |
| LLaMA-Factory QLoRA | ~1480 | ~1.0× | ~16 GB |
| TRL QLoRA | ~1450 | ~0.97× | ~16 GB |
实战推荐:
- 入门 / GUI:LLaMA-Factory(可后端切 Unsloth 加速)。
- 研究 / 自定义训练流程:TRL。
- 端到端项目(SFT + DPO + 评测):Axolotl。
- 消费级 GPU 单卡:Unsloth。
11.8 端到端最佳实践 Checklist
把整章总结成一个 checklist,开始 SFT 项目时逐条对照:
数据
- [ ] 数据已统一为 OpenAI Chat 格式
- [ ] 已与评测基准做去污(13-gram 检测)
- [ ] 长度分布统计,决定 max_seq_length
- [ ] 多领域按 token 数(不是样本数)配比
- [ ] 用 RM / PPL 过滤低质量样本
模型
- [ ] 选定 base model(参数规模、license、社区评分)
- [ ] 决定全参 or LoRA / QLoRA / DoRA
- [ ] LoRA 时确认 target_modules 覆盖所有 linear
- [ ] tokenizer chat template 与训练 / 推理一致
- [ ] 添加新 special token 后 resize embedding
训练
- [ ] 确认
assistant_only_loss=True(或手动 mask) - [ ] BF16 + Flash Attention 2/3 + gradient checkpointing
- [ ] LR、batch size、epoch 按 §11.6 设置
- [ ] warmup 5%、cosine scheduler
- [ ] 多卡用 FSDP
FULL_SHARD+use_orig_params=True - [ ] 长上下文用 YaRN / NTK-by-parts 扩展 RoPE
- [ ] packing 配合
cu_seqlens防跨样本污染
评测
- [ ] 对齐前后跑 MMLU、GSM8K、HumanEval 看通用能力是否退化
- [ ] 任务专项 eval(如垂直领域指标)
- [ ] 与 base model 对比 win rate(用 LLM-as-judge 或 reward model)
- [ ] 检查格式合规率(JSON 解析率、终止 token 出现率)
- [ ] 检查重复率、拒答率
部署
- [ ] LoRA 合并为完整模型(如需)
- [ ] 用 vLLM / SGLang 测试推理速度
- [ ] 与训练时 chat template 严格一致
- [ ] 设置正确的
eos_token_id、pad_token_id
11.9 本章小结
- 决策树:垂直任务 / 数据少 / 显存紧 → LoRA;行为大改 / 数据充足 / 算力够 → 全参。
- 多轮对话:整段拼接 + 全 mask(除 assistant 外),效率最高。Packing 必须配合
cu_seqlens。 - 长上下文:YaRN 是 2024-2026 年事实标准,把 LLaMA-2 扩到 128K 仅需数百步微调。
- 分布式:FSDP
FULL_SHARD+use_orig_params=True是 LoRA 多卡训练的最优配置;70B QLoRA + FSDP 单机 4×H100 即可。 - 常见坑:忘 mask、LR 过高、BF16/FP16 混用、grad checkpointing reentrant、模板不一致、未做去污。
- 工具选择:研究 TRL、工程 Axolotl、入门 LLaMA-Factory、单卡 Unsloth。
思考题
你的团队有 8×H100 80GB 集群和 100K 条高质量 SFT 数据,要训练 LLaMA-3-70B-base → chat。 给出完整的训练方案(全参 vs PEFT、batch size、LR、epoch、数据混合、长上下文方案、评测),并估算训练时间与成本。
SFT 后的模型在通用任务(MMLU、GSM8K)上下降了 5 个点,但在目标垂直任务上提升了 20 个点。这是否可以接受? 讨论何时这是「划算的交换」、何时不是;以及如何尽量缩小通用能力损失。
假设你发现训练数据中有 0.5% 的样本与 MMLU 题目高度重叠(n-gram score > 0.7),但这些样本是合成数据自带的,无法精确去除。你是直接训练并接受污染,还是花一周时间手工清洗? 量化两种选择的代价(训练时间、模型可信度、上线进度),并给出你的判断。