4. FlashAttention 深度剖析
导读:标准注意力的计算量是
,但实际瓶颈不是计算而是访存—— 的中间矩阵反复在 HBM 与 SRAM 之间倒腾。FlashAttention 通过 tiling + online softmax 把访存降到 ,在不损失精度的前提下加速 2-4 倍。本章逐步推导其原理、算法、IO 复杂度证明,以及 v1/v2/v3 的演进。
4.1 GPU 内存层级与访存瓶颈
4.1.1 内存金字塔
GPU(以 A100 为例)有多层内存:
| 级别 | 容量 | 带宽 | 延迟 |
|---|---|---|---|
| HBM (DRAM) | 80 GB | 1.5 TB/s | 300-500 ns |
| L2 cache | 40 MB | ~5 TB/s | ~150 ns |
| SRAM (per SM) | 192 KB | ~19 TB/s | ~30 ns |
| 寄存器 | 256 KB / SM | 极快 | 1 cycle |
H100 的 HBM 带宽提升到 3 TB/s,SRAM 也增至 228 KB/SM。但容量与带宽的差距依然显著——SRAM 比 HBM 快 10 倍以上,但只有它的百万分之一大小。
4.1.2 计算 vs 访存:Roofline 模型
一个 kernel 的实际运行时间
- Compute-bound:
,受 FLOPS 限制 - Memory-bound:
,受带宽限制
算术强度 (Arithmetic Intensity):
A100 的 BF16 算力 312 TFLOPS,HBM 带宽 1.5 TB/s,平衡点 AI =
4.1.3 标准注意力的算术强度分析
输入
| 步骤 | 计算 | HBM 读 | HBM 写 |
|---|---|---|---|
| 1. | |||
| 2. | |||
| 3. |
总 FLOPs:
进一步分析:占主导的访存是中间矩阵
4.1.4 显存占用:另一个隐患
标准注意力存储
LLaMA-2 7B (32 头 32 层) + 32K 上下文:
显然不可能放在显存里——所以工程上必须分块。但分块的标准实现仍要把每块的
4.2 在线 Softmax (Online Softmax)
4.2.1 安全 Softmax 的三趟实现
数值稳定的 softmax:
这需要遍历向量 3 次:
- 求
- 求
(需要先有 ) - 算
(需要先有 )
3 趟意味着对长向量需要重复读取 3 次。
4.2.2 两趟版本
把第 1、2 趟合并为一趟(同时维护
关键观察:当
证明
- 若
: ,符合定义 - 若
(因为 ):
第 3 趟(计算
4.2.3 一趟的 softmax × matmul:核心技巧
在注意力中,我们不直接需要
关键洞察:把"未归一化的输出"和"归一化系数"分开维护,最后一次性除。
定义:
最终输出:
递推:
只需 1 趟!而且每个状态都可以增量更新。
4.2.4 块级合并(Tiling 友好)
更进一步,两个块的 softmax 状态可以结合律式合并:
设块
最终输出:
这构成一个 monoid(满足结合律),可以任意分块、任意并行规约。
实现
def merge_softmax_block(m_a, l_a, O_a, m_b, l_b, O_b):
m = torch.maximum(m_a, m_b)
e_a = (m_a - m).exp()
e_b = (m_b - m).exp()
l = e_a * l_a + e_b * l_b
O = e_a.unsqueeze(-1) * O_a + e_b.unsqueeze(-1) * O_b
return m, l, O注意:
这就是 FlashAttention 的算法基础。
4.3 FlashAttention v1
4.3.1 核心思想
Dao et al. (2022) "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness":
- Tiling:把
按行/列分块,每次只在 SRAM 里处理一小块 - Online softmax:边算边维护
- 不写中间矩阵
到 HBM:只在 SRAM 中临时计算
4.3.2 算法(前向)
输入: Q, K, V ∈ R^{N×d}
输出: O ∈ R^{N×d}
设块大小 B_r, B_c 满足 B_r·d + B_c·d + B_r·B_c ≤ M (SRAM 大小)
T_r = ⌈N / B_r⌉, T_c = ⌈N / B_c⌉
将 Q 切成 T_r 个行块 {Q_1, ..., Q_{T_r}},每个 Q_i ∈ R^{B_r × d}
将 K, V 切成 T_c 个列块
外循环 i = 1..T_r:
加载 Q_i 到 SRAM
初始化 O_i ← 0 ∈ R^{B_r × d}
初始化 ℓ_i ← 0 ∈ R^{B_r}
初始化 m_i ← -∞ ∈ R^{B_r}
内循环 j = 1..T_c:
加载 K_j, V_j 到 SRAM
# 计算分块注意力分数
S_ij = Q_i · K_j^T / √d # ∈ R^{B_r × B_c}, 仅在 SRAM
# 在线 softmax 更新
m_ij = rowmax(S_ij) # ∈ R^{B_r}
m_new = max(m_i, m_ij)
P_ij = exp(S_ij - m_new[:, None]) # ∈ R^{B_r × B_c}
l_new = exp(m_i - m_new) * ℓ_i + rowsum(P_ij)
O_i = exp(m_i - m_new)[:, None] * O_i + P_ij · V_j
m_i = m_new
ℓ_i = l_new
# 循环结束后归一化
O_i = O_i / ℓ_i[:, None]
写回 O_i 到 HBM
存 L_i = m_i + log(ℓ_i)(反向用)注意:对于因果掩码,仅当
4.3.3 IO 复杂度证明
定理 (Dao 2022, Theorem 2):设 SRAM 大小
而标准注意力的 HBM 访问量为
证明思路:
- 块大小约束:
。最优选 , - 内循环每次加载
共 元素,且外循环每次都全量遍历 (共 次内循环) - 每个外循环迭代加载
共 元素 - 外循环
次:总 加载量 - 取
:总加载 加载量 ,输出 写入 ,相比之下可忽略
对比:
| 算法 | HBM 访问 | A100 (M=192KB) 比例 |
|---|---|---|
| 标准 | 1 | |
| FlashAttention |
实测在 A100 上 FlashAttention 加速 GPT-2 (N=1024, d=64) 约 3 倍。
4.3.4 反向传播
朴素反向需要
- 只存
和 (每行一个标量, 个) - 反向时重新计算
- 这是一种 selective recomputation:用
额外显存换 节省
反向 IO 复杂度同
4.3.5 显存
| 算法 | 中间显存 |
|---|---|
| 标准 | |
| FlashAttention |
这让 32K-128K 上下文成为可能。
4.3.6 实测性能
Dao 2022 报告:
- GPT-2 small (N=1024, d=64): forward+backward 加速 3.5x
- BERT-large (N=512, d=64): 端到端训练加速 15%
- 长序列 (N=16K) 内存节省 10-20x
但 FlashAttention v1 的 GPU 利用率仅 25-40%,仍未榨干硬件。
4.4 FlashAttention v2
4.4.1 v1 的瓶颈
Dao 2023 "FlashAttention-2" 分析 v1 在 H100 / A100 上跑得不够快的原因:
- 非矩阵乘 FLOPs 占比高:每内循环都要 rescale
,引入大量 element-wise 操作 - 外循环并行不足:仅按 batch × heads 并行,对小 batch + 长序列利用率差
- Warp 间通信开销:v1 把
分给不同 warp,每步要 warp 间 reduce
4.4.2 改进 1:减少 rescale
v1 每个内循环都 rescale
内循环中只维护未归一化的 \tilde{O}_i 和未除的 \ell_i:
\tilde{O}_i ← \tilde{O}_i · e^{m_i - m_new} + P_ij V_j # 仍需 rescale
\ell_i ← \ell_i · e^{m_i - m_new} + rowsum(P_ij)
循环结束后:
O_i = \tilde{O}_i / \ell_i实际上 v1 也是这种结构,但 v2 做了进一步的 IO 优化:把每个 block 的 rescale 因子算法重整理,让它能 fuse 到 GEMM 中。
4.4.3 改进 2:交换循环顺序
v1:外
v2:外
for i = 1..T_r: # Q 块(外)
for j = 1..T_c: # K 块(内,仅看 j ≤ i 的因果范围)
...
O_i = \tilde{O}_i / \ell_i
write O_i外循环可在
4.4.4 改进 3:序列并行
v1 的并行度 = batch_size × n_heads。当 batch=1(推理)或长序列训练时,GPU 可能跑不满(一个 H100 有 132 个 SM,需要至少 132 个并发 thread block)。
v2 的并行度 = batch_size × n_heads ×
4.4.5 改进 4:Warp 协作
v1 把
v2 把
4.4.6 综合效果
A100 (BF16,
| 实现 | TFLOPS | 利用率 |
|---|---|---|
| PyTorch SDPA | 78 | 25% |
| FlashAttention v1 | 124 | 40% |
| FlashAttention v2 | 226 | 73% |
| 理论峰值 | 312 | 100% |
v2 比 v1 快约 2x,比 PyTorch 标准实现快近 3x。在 LLaMA-2 7B 训练上端到端加速 25-30%。
4.5 FlashAttention v3 (Hopper)
4.5.1 H100 新特性
Hopper (H100) 引入了一系列异步特性:
- WGMMA (Warpgroup Matrix Multiply Accumulate):异步张量核指令,4 个 warp 协作发射,FLOPs 翻倍
- TMA (Tensor Memory Accelerator):硬件 DMA 引擎,专门搬运 tile,无需 warp 参与
- FP8 张量核:E4M3/E5M2 格式,1979 TFLOPS(vs BF16 989 TFLOPS)
- 更大 SRAM:每 SM 228 KB
FlashAttention v2 在 H100 上仅达 35% 峰值(BF16 ~700 TFLOPS)——因为它没有用到这些异步能力。
4.5.2 改进 1:Warp 专门化 (Producer-Consumer)
Shah et al. (2024) "FlashAttention-3" 把 warp 分两组:
- Producer warp:用 TMA 把下一块
从 HBM 加载到 SRAM(异步,与计算重叠) - Consumer warp:用 WGMMA 算当前块的
和
像 CPU 流水线一样,producer 提前预取,consumer 不等待。
4.5.3 改进 2:异步流水
更细致地,consumer 内部也分阶段:
stage 0: 算 S_{ij} = Q_i K_j^T (WGMMA1, 异步发射)
同时: producer 加载 K_{j+1}, V_{j+1}
stage 1: 等 WGMMA1 完成,算 softmax(S_{ij})
同时: WGMMA1 已完成,发射 WGMMA2 算 P_{ij} V_j
stage 2: 等 WGMMA2 完成,更新 O_iGEMM 与 softmax 重叠,softmax 与下一块 GEMM 重叠。
4.5.4 改进 3:FP8 支持
E4M3 用于前向(精度优先),E5M2 用于反向(动态范围优先)。
挑战:
- FP8 动态范围窄,注意力分数容易溢出
- softmax 的指数运算对精度敏感
解决:
- 每块缩放:每个
块在加载时做一次 scale,记录 scale factor - 混合精度累加:FP8 GEMM 的累加器仍是 FP32
- Hadamard 变换:在 Q、K 上施加随机 Hadamard 矩阵,把数值分布拉平,降低量化误差
# 概念示意(非真实代码)
H = hadamard(d) # Hadamard 矩阵, 元素 ±1/√d
Q_h = Q @ H # 旋转到 "incoherent" 基
K_h = K @ H # 同上
# Q K^T = (Q H) (K H)^T = Q K^T (H H^T = I),结果不变
# 但 Q_h, K_h 的元素分布更"均匀",FP8 量化误差小4.5.5 实测性能
H100 (SXM5, BF16/FP8):
| 实现 | BF16 TFLOPS | FP8 TFLOPS |
|---|---|---|
| FlashAttention v2 | 348 | - |
| cuDNN | 460 | - |
| FlashAttention v3 | 740 | 1200 |
| 理论峰值 | 989 | 1979 |
v3 BF16 利用率达 75%,FP8 利用率 60%。比 v2 快 2x,比 cuDNN 快 1.6x。
DeepSeek-V3 用 FlashAttention v3 + FP8 是首个公开的 FP8 大规模预训练,在 H800 上达成 ~50% MFU(密集训练里很高)。
4.6 FlashAttention 系列对比
| 版本 | GPU | BF16 利用率 | 关键创新 | 何时发布 |
|---|---|---|---|---|
| Standard | A100 | ~25% | baseline | - |
| FlashAttention v1 | A100 | 25-40% | tiling + online softmax | 2022.05 |
| FlashAttention v2 | A100 | 50-73% | 交换循环、序列并行、warp 独立 | 2023.07 |
| FlashAttention v3 | H100 | 75% (BF16), 60% (FP8) | warp 专门化、TMA、FP8 | 2024.07 |
显存:所有版本都是
4.7 实战:使用 FlashAttention
4.7.1 PyTorch 内置 SDPA
PyTorch 2.0+ 的 scaled_dot_product_attention 自动选择最优后端:
import torch.nn.functional as F
# 自动选择 FlashAttention / Memory-Efficient / Math
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=0.0,
is_causal=True,
)可手动控制后端:
from torch.nn.attention import sdpa_kernel, SDPBackend
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)PyTorch 2.5+ 默认在 H100 上用 FlashAttention v2,未来会切到 v3。
4.7.2 flash-attn 库(独立)
Dao 团队的官方 CUDA 实现:
pip install flash-attn --no-build-isolationfrom flash_attn import flash_attn_func
# q, k, v: [batch, seq_len, n_heads, d_head]
out = flash_attn_func(q, k, v, causal=True, softmax_scale=1.0/math.sqrt(d_head))支持:
flash_attn_varlen_func:变长序列(无 padding)flash_attn_with_kvcache:推理时的 KV-Cache 接口flash_attn_qkvpacked_func:QKV 打包(减少访存)
4.7.3 Triton 实现
OpenAI Triton 提供 FlashAttention 的 Triton 实现,便于研究和定制:
# triton/python/tutorials/06-fused-attention.py
import triton
import triton.language as tl
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
start_m, qk_scale, BLOCK_M, BLOCK_N, ...):
# 内循环:遍历 K, V 块
for start_n in range(0, ...):
k = tl.load(K_block_ptr)
qk = tl.dot(q, k) * qk_scale
m_ij = tl.maximum(m_i, tl.max(qk, 1))
p = tl.exp(qk - m_ij[:, None])
alpha = tl.exp(m_i - m_ij)
l_i = alpha * l_i + tl.sum(p, 1)
v = tl.load(V_block_ptr)
acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
m_i = m_ij
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
return acc, l_i, m_i完整 Triton 实现约 200 行,性能接近 CUDA 版(80-90%)。
4.7.4 何时用何种实现
| 场景 | 推荐 |
|---|---|
| 标准训练 / 推理 | PyTorch SDPA(自动选择) |
| 极致性能 + 标准 attn | flash-attn 库 |
| 变长序列(无 padding) | flash-attn varlen_func |
| 自定义 attn 变体(ALiBi, sliding window) | Triton 实现,自己改 kernel |
| H100 + FP8 | FlashAttention v3(cuda 12+) |
4.8 局限与展望
4.8.1 不支持的特性
FlashAttention 主要针对标准 scaled dot-product。以下场景需要特殊处理:
- ALiBi:相对位置偏置加在
上,FlashAttention v2.1+ 支持 - Sliding Window:Mistral 用,FlashAttention 支持
window_size参数 - Mixture of Attention:每个 head 用不同 attention 模式,需要分别 kernel
- Cross-attention:encoder-decoder 中,K/V 来自不同序列,FlashAttention 也支持
4.8.2 极长上下文
序列长度
- Sparse Attention (BigBird, Longformer):稀疏化 attention 矩阵
- Linear Attention (Performer, Linear Transformer):核技巧近似 softmax,复杂度
- State Space Models (Mamba, S4):完全替换 attention,复杂度
- Ring Attention (Liu 2023):跨 GPU 切分序列,单卡看
长度,全局通信打通
4.8.3 推理优化的特殊性
Prefill 阶段适合 FlashAttention(长序列、并行)。
Decode 阶段每次只生成 1 个 token,序列维度退化为 1,并行度极低。这时需要特殊 kernel:
- FlashDecoding:把 K-cache 切成块,多个 SM 协作算单个 query
- PagedAttention (vLLM):页式 KV-Cache 管理,提高并发
- Group-Query Attention 友好实现:减少 KV 访存
4.9 IO 复杂度的进一步思考
4.9.1 下界
Dao 2022 证明:任何精确注意力算法的 HBM 访问下界是
FlashAttention 达到这个下界(up to constant),是 IO 最优的。
4.9.2 与 GEMM 的类比
GEMM 的最优 IO 是
理解这个,对设计任何 GPU kernel 都有启发:算术强度低的算子,要靠 tiling + 异步预取榨干带宽。
4.9.3 未来方向
- 跨层融合:attention + FFN 共享 SRAM tile,减少层间 HBM 倒腾
- Dynamic shape:变长序列动态 tile size,减少 padding 浪费
- 混合精度优化:FP4 / INT8 等低精度的 attention(仍在研究)
- ASIC 特化:Cerebras WSE、Groq 等专门硬件,把 attention 视为一等公民
4.10 本章小结
- 注意力是 memory-bound:算术强度仅
,远低于 A100 平衡点 208。瓶颈是 中间矩阵的 HBM 读写。 - Online softmax 提供了边算边归一的能力,且块级合并满足结合律——这是 tiling 可行的数学基础。
- FlashAttention v1:tiling + online softmax,HBM 访问
,比标准减少约 倍;显存从 降至 。 - FlashAttention v2:交换循环(外 Q 内 K)、序列维并行、warp 独立,A100 利用率从 40% 提升到 73%。
- FlashAttention v3:Hopper 异步特性(WGMMA、TMA、FP8、warp 专门化),H100 BF16 75%、FP8 60%。
- 生产实践:PyTorch 2.5+ SDPA、flash-attn 库、Triton 实现,按需选择。
下一章我们讨论稀疏激活的另一条路线——MoE 架构。
4.11 思考题
IO 下界推导:假设
,SRAM 大小 , 。请证明任何精确注意力算法的 HBM 访问量至少为 。提示:考虑信息论下界,每个 需要至少访问 和 各一次。 块大小设计:A100 的 SRAM 大小
KB,FP16 下每元素 2 字节。 ,FlashAttention v2 应取 多大?分别考虑约束 (除以 2 留余量给寄存器和栈)。给出至少 2 组合理配置并比较。 FP8 误差分析:FP8 E4M3 的最大表示数
,最小正规数 。设 , ,则 的分布是什么?该分布在 FP8 下如何缩放才能避免溢出?请定量说明 per-block scaling 的必要性。