Skip to content

11. SFT 实战与工程

前三章讲清了 SFT 的「是什么、为什么、用什么数据、怎样降显存」,本章把这些组件拼装成一个完整的工程项目。我们讨论决策树、多轮对话、长上下文扩展、分布式训练、常见踩坑、超参数实战值,以及主流工具链对比。

11.1 全参 vs PEFT:决策树

不是所有任务都该用 LoRA,也不是所有任务都该全参 SFT。下面是一个决策树:

                            ┌──────────────────────────┐
                            │  你要做什么?             │
                            └────────────┬─────────────┘

              ┌──────────────────────────┼──────────────────────────────┐
              ▼                          ▼                              ▼
   大幅改变模型行为                 行业知识注入              垂直任务(分类/抽取/QA)
   (base → chat,                   (医疗/法律/金融)         样本通常 < 50K
    多语言 alignment)               需要新事实
              │                          │                              │
              ▼                          ▼                              ▼
        全参 SFT                   ┌─ 数据 > 100K          LoRA r=8-16
   (大量数据, 改风格)             │  显存充足
                                  │       │ Yes
                                  │       ▼
                                  │   全参 SFT (LR ≤ 1e-5)

                                  └─ 数据 ≤ 100K 或显存紧


                                  LoRA r=64+ 或 DoRA r=32
                                  混合通用数据防遗忘

11.1.1 何时选全参 SFT

  • 从 base model 训练 chat model(如 Llama-3-base → Llama-3-Instruct):行为分布大幅改变。
  • 大型 SFT 数据(> 200K,多领域混合):参数容量需求高。
  • 多模态 alignment(文本 + 图像 / 音频):模态对齐需要深入修改表征。
  • 算力充足(多张 H100 / A100,训练时间不是瓶颈)。

11.1.2 何时选 LoRA / QLoRA

  • 垂直任务(医疗 NER、法律分类、客服 QA):本质上调整输出格式 + 少量知识注入。
  • 多任务并存:一个 base + 多个 LoRA adapter,部署成本极低。
  • 数据量小(< 5K):LoRA 的低参数量自带正则化,过拟合风险更低。
  • 显存受限:单卡 24GB 训 7B、48GB 训 70B,QLoRA 是唯一选择。
  • 快速实验:训练耗时短,迭代快。

经验法则:LoRA 适合「改风格 + 加技能」,但若要「改世界观」(注入大量新事实),仍需全参或较大 r

11.2 多轮对话 SFT

11.2.1 拼接策略

最常见做法:把多轮对话用 chat template 拼成一整条 sequence,对每个 assistant 段落都计算 loss:

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
你好<|im_end|>
<|im_start|>assistant
你好!有什么可以帮你?<|im_end|>          ← 计算 loss
<|im_start|>user
讲个笑话<|im_end|>
<|im_start|>assistant
为什么程序员喜欢黑暗?因为光会带来 bug。<|im_end|>  ← 计算 loss
  • 优点:单 sample 训练所有 assistant 轮,效率高,模型学到「在对话上下文里如何响应」。
  • 注意:所有 user/system/tool 部分必须 mask(label = -100)。

11.2.2 多轮 vs 单轮分裂

另一种策略是把多轮对话「分裂」成多个单轮样本:

样本 1: [system, user₁, assistant₁]
样本 2: [system, user₁, assistant₁, user₂, assistant₂]
样本 3: [system, user₁, assistant₁, user₂, assistant₂, user₃, assistant₃]

每个样本只对最后一轮 assistant 计算 loss。

维度整轮拼接 + 全 mask分裂为多个单轮
训练效率高(一次 forward)低(重复前缀)
长度浪费大(前缀重复)
早期 turn 学习信号充分充分(多次重复前缀)
实现复杂度简单简单
推荐首选仅在 base model 不善于多轮时用

11.2.3 TRL SFTTrainer 多轮代码

python
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
# 每个样本已是 messages 列表格式

config = SFTConfig(
    assistant_only_loss=True,    # 自动对所有 assistant 段算 loss
    max_seq_length=4096,
    packing=False,               # 多轮对话不建议 packing(除非用 4D mask)
)
trainer = SFTTrainer(model=model, args=config, train_dataset=ds, tokenizer=tokenizer)

assistant_only_loss=True 会:

  1. 调用 chat template 时识别 {% generation %} 标记。
  2. 对所有 assistant 段生成 mask。
  3. 把非 assistant 位置的 label 置为 -100。

11.2.4 处理 system prompt

  • 训练时:system prompt 的内容不计算 loss(被 mask)。
  • 推理时:system prompt 由用户/应用控制。
  • 多样化 system prompt:训练时让 system 多样(不同人设、不同任务说明),模型才能在推理时理解新 system。

11.2.5 Sequence Packing

为减少 padding 浪费,把多个短对话打包成一条长 sequence:

原始(max_len=4096,平均每条 800 tokens):
  [conv1: 800] [pad: 3296]  ← 80% 浪费
  [conv2: 600] [pad: 3496]
  ...

Packing:
  [conv1: 800][conv2: 600][conv3: 1200][conv4: 900][conv5: 596]  ← 0 浪费

关键:packing 时必须用 attention mask 隔离样本,否则 conv1 的 token 会 attend 到 conv2,污染训练信号。

两种实现:

  1. 4D attention mask:构造 (B, H, T, T) mask,跨样本位置置 0。显存开销大。
  2. FlashAttention cu_seqlens:传入累积序列长度数组(如 [0, 800, 1400, 2600, 3500, 4096]),FlashAttn 内部按段计算。推荐

Axolotl、TRL 都支持 packing:

python
config = SFTConfig(
    packing=True,
    max_seq_length=8192,
    eval_packing=False,    # eval 时关掉 packing 便于看每条样本指标
)

实战收益:在 UltraChat 上 packing + flash_attn 比无 packing 训练快 2-4 倍(具体取决于长度分布)。

11.3 长上下文 SFT

预训练 4K 的模型要支持 32K/128K 上下文,需要位置编码扩展(针对 RoPE 模型)。

11.3.1 RoPE 复习

旋转位置编码(RoPE)把每个 query/key 向量按位置 m 旋转:

qm=R(mθ)q,kn=R(nθ)k

其中 R(ϕ) 是 2D 旋转矩阵,θi=100002i/d。点积 qmkn 仅依赖相对位置 mn

问题:模型只在训练长度 L 内见过相对位置,超出 L 后高频维度旋转角度会进入「未见过」区域 → 性能崩溃。

11.3.2 Position Interpolation(PI, Chen et al., 2023, Meta)

把 position 索引等比缩放到训练范围内:

f(xm,m,θ)=f(xm,mL/L,θ)

其中 L 是预训练长度,L 是目标长度。例如 L=4K,L=32K,则把所有位置 m 除以 8。

  • 优点:实现简单,与训练分布最一致。
  • 缺点:高频维度被「压缩」,丢失细粒度位置区分能力。
  • 需要少量长上下文数据微调(< 1000 步)。最大可扩展约 8×。

11.3.3 NTK-aware(bloc97, 2023)

不直接缩放 position,而是改变 RoPE 的基底 θ

θi=θs2i/(d2),s=(L/L)d/(d2)

直觉:高频维度(θi 大)保持原样不缩放(外推),低频维度(θi 小)大幅缩放(插值)。这样避免高频信息丢失

NTK-aware 可以做无需微调的 zero-shot 扩展(2-4×),效果不错。

11.3.4 YaRN(Peng et al., 2023, EleutherAI / Nous)

「Yet another RoPE extensioN」综合了 NTK-by-parts 和 attention temperature scaling:

Step 1: NTK-by-parts

根据 RoPE 维度的「波长」λi=2π/θi 分段处理:

  • 高频λi 远小于上下文)→ 不插值(外推)。
  • 低频λi 远大于上下文)→ 全插值(PI)。
  • 中间λi 接近上下文)→ 平滑过渡(ramp 函数,由 βfast,βslow 控制)。

Step 2: Attention temperature scaling

因 context 变长,attention logits 分布变化(softmax 的「峰值锐度」下降),需缩放:

1t=aln(s)+b,s=L/L

LLaMA 实测:a=0.1,b=1

效果

YaRN 把 LLaMA-2 扩展到 128K 仅需 400-600 步微调(比 PI 减少 10×)。Mistral 7B 128K、Qwen2、Qwen2.5 都用 YaRN 作为长上下文方案。

11.3.5 LongLoRA(Chen et al., ICLR 2024)

观察:long-context SFT 显存瓶颈是 attention(O(L2))。LongLoRA 提出:

Shifted Sparse Attention(S²-Attn)

训练时把序列分成多块,每块内做 full attention,半数 head 做 shift(位移半块),模拟跨块信息交换:

Head 1-N/2:  ┌──┐ ┌──┐ ┌──┐ ┌──┐    (块内 attention)
              [b1] [b2] [b3] [b4]

Head N/2+1-N: shift by half block, then 块内 attention
              ←─[b1+b2/2]→ ←─[b2/2+b3]→ ...

这种近似在训练时显存 O(Lc)c 是块大小),推理时仍用 full attention。

LoRA + 全部 norm/embedding 训练

LongLoRA 配合 LoRA 训练 attention,额外让 norm 和 embedding 全部可训练(这两类参数对长上下文特别关键)。

LLaMA-2 7B 扩到 100K 仅用 8×A100。

11.4 分布式 SFT

单卡或单机已不够大模型 SFT。两大主流分布式方案:FSDP 和 DeepSpeed ZeRO。

11.4.1 FSDP(PyTorch Fully Sharded Data Parallel)

PyTorch 原生方案。把模型参数、梯度、优化器状态切片到所有 GPU:

模式切片内容等价 ZeRO
FULL_SHARD参数 + 梯度 + 优化器状态ZeRO-3
SHARD_GRAD_OP梯度 + 优化器状态ZeRO-2
NO_SHARD不切(仅 DDP)DDP
HYBRID_SHARD节点内 FULL_SHARD + 节点间复制-

LoRA + FSDP:base 权重切片 + LoRA adapter 各 GPU 完整副本(参数小可承受)。70B QLoRA + FSDP 可在 4×H200 上跑。

11.4.2 DeepSpeed ZeRO

微软方案,三个阶段:

  • ZeRO-1:切优化器状态。
  • ZeRO-2:再切梯度。
  • ZeRO-3:再切参数(等价 FSDP FULL_SHARD)。

ZeRO-3 + offload(CPU/NVMe)可把 70B 全参 SFT 跑在普通服务器(虽然慢)。但通信开销较 FSDP 大,LoRA 场景下 FSDP 通常更优

11.4.3 配置示例

Accelerate FSDP config

yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
num_processes: 8
fsdp_config:
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_offload_params: false
  fsdp_use_orig_params: true       # LoRA 必须 true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true

启动:

bash
accelerate launch --config_file fsdp.yaml train.py

DeepSpeed ZeRO-3 config

json
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {"device": "cpu", "pin_memory": true},
    "offload_param": {"device": "cpu", "pin_memory": true},
    "overlap_comm": true,
    "contiguous_gradients": true,
    "stage3_gather_16bit_weights_on_model_save": true,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9
  },
  "bf16": {"enabled": true},
  "gradient_accumulation_steps": 4,
  "gradient_clipping": 1.0,
  "train_micro_batch_size_per_gpu": "auto"
}

启动:

bash
deepspeed --num_gpus=8 train.py --deepspeed ds_config.json

11.4.4 QLoRA + FSDP 组合

QLoRA + FSDP 是 2024 年最流行的「单机 70B SFT」方案:

  • base 量化为 4-bit(仅 ~35 GB / 70B)。
  • FSDP 把 4-bit base 在 4-8 卡间切片。
  • LoRA adapter 各卡完整副本(小,可忍)。
  • 4×H100 80GB 即可微调 70B。

注意:FSDP + bitsandbytes 早期有兼容性问题(2023 年),现已稳定。HuggingFace Accelerate 提供了开箱即用的脚本。

11.5 常见问题与对策

11.5.1 灾难性遗忘(Catastrophic Forgetting)

模型 SFT 后丢失通用能力(如数学、推理变差)。

症状

  • MMLU、GSM8K、HumanEval 等通用基准下降。
  • 模型只在训练数据风格上表现好,问别的就「装糊涂」。
  • 多语言能力丢失。

原因

  • LR 过高,破坏预训练表示。
  • 数据偏窄(如全是医疗对话)。
  • 训练步数过多。

缓解策略

策略适用场景实现
降低 LR始终适用全参 2e-5 → 5e-6;LoRA 1-2e-4
混入通用数据垂直领域 SFTdomain : general = 1 : 1 ~ 1 : 3
使用 LoRA显存允许时优选改变小,遗忘少
早停数据少时监控 dev set 通用能力 + 任务表现
EWC / Replay学术研究LLM 中很少使用

ICML 2025 有论文(Improved SFT for Mitigating Catastrophic Forgetting)提出:用基础模型自身重建通用指令分布 + 多模型生成过滤合成通用数据,与新数据混合训练。

11.5.2 过拟合(小数据场景)

  • 1K-10K 样本,1-3 epoch 即可(更多通常过拟合)。
  • 监控 train/eval loss,若 eval 反弹立即停。
  • LR warmup 5%-10% steps 线性升温。
  • LR scheduler 用 cosine decay 到 10% peak。

11.5.3 格式不合规

模型生成的输出不遵循 chat template(如不输出 <|im_end|>、漏掉 JSON 闭合括号)。

原因 1:训练数据本身格式不严谨。
对策:清洗时做格式校验(解析 JSON、检查 token 完整性)。

原因 2:训练时未对 EOS / 终止 token 计算 loss(mask 太激进)。
对策:确保 <|im_end|> 等终止 token 是 assistant 部分,被纳入 loss。

原因 3:generation 配置未对齐。
对策:推理时设置 eos_token_id 为模型实际使用的终止 token。

11.5.4 重复生成

模型生成时陷入循环(重复同一句、同一段)。

对策

  • 推理设置 repetition_penalty=1.05-1.15no_repeat_ngram_size=3
  • 检查 SFT 数据中是否有大量重复样本(合成数据常见问题)。
  • 适度提高 temperature (0.7-0.9) 和 top_p (0.9-0.95)。

11.5.5 Loss 为 NaN

  • 检查是否混用 FP16 / BF16 不当。BF16 更稳定,应优先。
  • 检查 gradient_checkpointing_kwargs={"use_reentrant": False}(reentrant 模式有梯度问题)。
  • 检查样本中是否有空 assistant 内容(mask 后整个序列全是 -100,loss 退化)。
  • 学习率过高(尤其 LoRA r 大时)。

11.6 超参数实战指南

超参全参 SFTLoRAQLoRA备注
Learning Rate1e-5 ~ 5e-51e-4 ~ 3e-41e-4 ~ 2e-4LoRA 比全参高 1 个量级
Batch Size (effective)64-25616-6416-32通过 grad accum 实现
Epochs1-31-51-3多于 3 一般过拟合
Warmup3-10%3-10%3-10%linear
Schedulercosinecosinecosinedecay 到 peak 的 10%
Weight Decay0.0-0.10.00.0LoRA 不需要
Max Seq Len4K-32K4K-32K2K-8K看显存
OptimizerAdamWAdamWpaged_adamw_8bitQLoRA 用 paged
Grad Clip1.01.00.3-1.0防梯度爆炸
Precisionbf16bf16bf16 (compute)FP16 不稳定
Grad Checkpointingyesyesyes省 ~40% 显存
Flash Attentionv2/v3v2/v3v2/v3几乎是标配

经验:

  • LoRA 的 LR 比全参高一个量级:因可训参数少,需要更大步长才有可见进展。
  • BF16 比 FP16 更稳定:动态范围相同于 FP32(只是精度低),无 loss scaling 问题。
  • Gradient checkpointing:节省显存约 40%,速度损失约 20-30%——绝大多数项目应开启。
  • Flash Attention 2/3:几乎是标配,提速 + 省显存。

11.7 主流工具生态

11.7.1 HuggingFace TRL(SFTTrainer

  • 定位:研究级标准库,原生 transformers 集成。
  • 优点
    • 文档全,社区活跃。
    • 支持 assistant_only_losscompletion_only_loss、packing、PEFT 无缝集成。
    • 同时支持 SFT / DPO / GRPO / RLOO / KTO / ORPO 等训练。
    • 与 Accelerate / DeepSpeed / FSDP 无缝。
  • 适合:研究、自定义训练逻辑、需要从 SFT 走到 RLHF 的端到端项目。
python
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(model=model, args=SFTConfig(...), train_dataset=ds)
trainer.train()

11.7.2 Axolotl

  • 定位:YAML 配置驱动,工业级 + 全功能。
  • 优点
    • 模板丰富(每个流行模型都有官方 example YAML)。
    • 完整 RLHF 流程(SFT/DPO/PPO/GRPO/RM)。
    • 多模态(LLaMA-Vision、Qwen2-VL、Pixtral)。
    • Sample packing、FSDP+QLoRA 一流支持
  • 适合:从 SFT 到 DPO 的端到端项目;不想写 Python 训练脚本。

YAML 示例(节选):

yaml
base_model: Qwen/Qwen3-7B
model_type: AutoModelForCausalLM
load_in_4bit: true

datasets:
  - path: HuggingFaceH4/ultrachat_200k
    type: chat_template
    chat_template: chatml
    field_messages: messages

adapter: qlora
lora_r: 64
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

micro_batch_size: 4
gradient_accumulation_steps: 4
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-4
warmup_ratio: 0.05

bf16: true
gradient_checkpointing: true
flash_attention: true

fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_offload_params: false
  fsdp_use_orig_params: true

启动:

bash
accelerate launch -m axolotl.cli.train config.yaml

11.7.3 LLaMA-Factory

  • 定位:Web UI + 100+ 模型支持。
  • 优点
    • 零代码,GUI 入门友好。
    • Unsloth 加速后端集成(约 2.1× 速度)。
    • 中文社区主导,中文文档完善。
    • ACL 2024 论文。
  • 适合:新手入门、零代码用户、中文用户。

11.7.4 Unsloth

  • 定位:单 GPU 极致性能。
  • 优点
    • Triton + 手写 CUDA kernel,2-5× 提速、70-80% 显存削减
    • 单 GPU 优化最佳。
    • 与 HuggingFace 生态兼容。
  • 缺点:多 GPU 支持有限(持续改进中)。
  • 适合:消费级 GPU(4090 / 3090)、预算有限场景。
python
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Qwen3-7B-bnb-4bit",
    max_seq_length=4096,
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model, r=16, target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                  "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16, use_gradient_checkpointing="unsloth",
)
# ... 后续与 TRL SFTTrainer 接续

11.7.5 速度与显存对比

LLaMA-3.1 8B QLoRA r=16, A100 80GB, max_seq_len=4096:

工具tokens/s相对速度VRAM
Unsloth (4-bit)~42002.8×~8 GB
Axolotl QLoRA~15001.0×~16 GB
LLaMA-Factory QLoRA~1480~1.0×~16 GB
TRL QLoRA~1450~0.97×~16 GB

实战推荐

  • 入门 / GUI:LLaMA-Factory(可后端切 Unsloth 加速)。
  • 研究 / 自定义训练流程:TRL。
  • 端到端项目(SFT + DPO + 评测):Axolotl。
  • 消费级 GPU 单卡:Unsloth。

11.8 端到端最佳实践 Checklist

把整章总结成一个 checklist,开始 SFT 项目时逐条对照:

数据

  • [ ] 数据已统一为 OpenAI Chat 格式
  • [ ] 已与评测基准做去污(13-gram 检测)
  • [ ] 长度分布统计,决定 max_seq_length
  • [ ] 多领域按 token 数(不是样本数)配比
  • [ ] 用 RM / PPL 过滤低质量样本

模型

  • [ ] 选定 base model(参数规模、license、社区评分)
  • [ ] 决定全参 or LoRA / QLoRA / DoRA
  • [ ] LoRA 时确认 target_modules 覆盖所有 linear
  • [ ] tokenizer chat template 与训练 / 推理一致
  • [ ] 添加新 special token 后 resize embedding

训练

  • [ ] 确认 assistant_only_loss=True(或手动 mask)
  • [ ] BF16 + Flash Attention 2/3 + gradient checkpointing
  • [ ] LR、batch size、epoch 按 §11.6 设置
  • [ ] warmup 5%、cosine scheduler
  • [ ] 多卡用 FSDP FULL_SHARD + use_orig_params=True
  • [ ] 长上下文用 YaRN / NTK-by-parts 扩展 RoPE
  • [ ] packing 配合 cu_seqlens 防跨样本污染

评测

  • [ ] 对齐前后跑 MMLU、GSM8K、HumanEval 看通用能力是否退化
  • [ ] 任务专项 eval(如垂直领域指标)
  • [ ] 与 base model 对比 win rate(用 LLM-as-judge 或 reward model)
  • [ ] 检查格式合规率(JSON 解析率、终止 token 出现率)
  • [ ] 检查重复率、拒答率

部署

  • [ ] LoRA 合并为完整模型(如需)
  • [ ] 用 vLLM / SGLang 测试推理速度
  • [ ] 与训练时 chat template 严格一致
  • [ ] 设置正确的 eos_token_idpad_token_id

11.9 本章小结

  • 决策树:垂直任务 / 数据少 / 显存紧 → LoRA;行为大改 / 数据充足 / 算力够 → 全参。
  • 多轮对话:整段拼接 + 全 mask(除 assistant 外),效率最高。Packing 必须配合 cu_seqlens
  • 长上下文:YaRN 是 2024-2026 年事实标准,把 LLaMA-2 扩到 128K 仅需数百步微调。
  • 分布式:FSDP FULL_SHARD + use_orig_params=True 是 LoRA 多卡训练的最优配置;70B QLoRA + FSDP 单机 4×H100 即可。
  • 常见坑:忘 mask、LR 过高、BF16/FP16 混用、grad checkpointing reentrant、模板不一致、未做去污。
  • 工具选择:研究 TRL、工程 Axolotl、入门 LLaMA-Factory、单卡 Unsloth。

思考题

  1. 你的团队有 8×H100 80GB 集群和 100K 条高质量 SFT 数据,要训练 LLaMA-3-70B-base → chat。 给出完整的训练方案(全参 vs PEFT、batch size、LR、epoch、数据混合、长上下文方案、评测),并估算训练时间与成本。

  2. SFT 后的模型在通用任务(MMLU、GSM8K)上下降了 5 个点,但在目标垂直任务上提升了 20 个点。这是否可以接受? 讨论何时这是「划算的交换」、何时不是;以及如何尽量缩小通用能力损失。

  3. 假设你发现训练数据中有 0.5% 的样本与 MMLU 题目高度重叠(n-gram score > 0.7),但这些样本是合成数据自带的,无法精确去除。你是直接训练并接受污染,还是花一周时间手工清洗? 量化两种选择的代价(训练时间、模型可信度、上线进度),并给出你的判断。

基于 MIT 协议发布