Skip to content

4. FlashAttention 深度剖析

导读:标准注意力的计算量是 O(N2d),但实际瓶颈不是计算而是访存——N×N 的中间矩阵反复在 HBM 与 SRAM 之间倒腾。FlashAttention 通过 tiling + online softmax 把访存降到 O(N2d2/M),在不损失精度的前提下加速 2-4 倍。本章逐步推导其原理、算法、IO 复杂度证明,以及 v1/v2/v3 的演进。

4.1 GPU 内存层级与访存瓶颈

4.1.1 内存金字塔

GPU(以 A100 为例)有多层内存:

级别容量带宽延迟
HBM (DRAM)80 GB1.5 TB/s300-500 ns
L2 cache40 MB~5 TB/s~150 ns
SRAM (per SM)192 KB~19 TB/s~30 ns
寄存器256 KB / SM极快1 cycle

H100 的 HBM 带宽提升到 3 TB/s,SRAM 也增至 228 KB/SM。但容量与带宽的差距依然显著——SRAM 比 HBM 快 10 倍以上,但只有它的百万分之一大小。

4.1.2 计算 vs 访存:Roofline 模型

一个 kernel 的实际运行时间 T=max(Tcompute,Tmemory)

  • Compute-boundTcompute>Tmemory,受 FLOPS 限制
  • Memory-boundTmemory>Tcompute,受带宽限制

算术强度 (Arithmetic Intensity)

AI=FLOPsBytes accessed

A100 的 BF16 算力 312 TFLOPS,HBM 带宽 1.5 TB/s,平衡点 AI = 312000/1500=208 FLOPs/byte。AI 低于 208 → memory-bound,反之 compute-bound。

4.1.3 标准注意力的算术强度分析

输入 Q,K,VRN×dN 序列长度,d 头维度(如 64)。

步骤计算HBM 读HBM 写
1. S=QK2N2dQ+K=2NdS=N2
2. P=softmax(S)3N2S=N2P=N2
3. O=PV2N2dP+V=N2+NdO=Nd

总 FLOPs4N2d总 HBM 访存4N2+4Nd(FP16,乘以 2 字节)

AIAttn=4N2d2(4N2+4Nd)=N2d2(N2+Nd)d2 (when Nd)

d=64 → AI ≈ 32,远低于 A100 平衡点 208。注意力是 memory-bound

进一步分析:占主导的访存是中间矩阵 S,P2×2×N2 字节)。消除 SP 的 HBM 读写就是 FlashAttention 的核心动机

4.1.4 显存占用:另一个隐患

标准注意力存储 S,P:每层每头 N22 字节 + N22 字节 = 4N2 字节。

LLaMA-2 7B (32 头 32 层) + 32K 上下文:

32324327682=4.4 TB

显然不可能放在显存里——所以工程上必须分块。但分块的标准实现仍要把每块的 S,P 写回 HBM 再读,访存代价高昂。


4.2 在线 Softmax (Online Softmax)

4.2.1 安全 Softmax 的三趟实现

数值稳定的 softmax:

m=maxjxj,d=jexjm,yj=exjmd

这需要遍历向量 3 次:

  1. m
  2. d(需要先有 m
  3. yj(需要先有 m,d

3 趟意味着对长向量需要重复读取 3 次。

4.2.2 两趟版本

把第 1、2 趟合并为一趟(同时维护 md):

m(i)=max(m(i1),xi)d(i)=d(i1)em(i1)m(i)+exim(i)

关键观察:当 m 更新时,旧的 d(i1) 中的指数偏移也要更新,乘以 em(i1)m(i)1

证明 d(N)=jexjm(N)(对结果归纳):

  • m(i)=m(i1)d(i)=d(i1)+exim(i1),符合定义
  • m(i)>m(i1)(因为 m(i)=xi):d(i)=d(i1)em(i1)m(i)+1=j<iexjm(i)+exim(i)=jiexjm(i)

第 3 趟(计算 yj)必须在 m,d 完全确定后才能做。

4.2.3 一趟的 softmax × matmul:核心技巧

在注意力中,我们不直接需要 yj,而是需要 jyjvj(即 AV 的某行)。

关键洞察:把"未归一化的输出"和"归一化系数"分开维护,最后一次性除

定义:

O~(i)=jiexjm(i)vj,(i)=jiexjm(i)

最终输出:O=O~(N)/(N)

递推:

O~(i)=O~(i1)em(i1)m(i)+exim(i)vi(i)=(i1)em(i1)m(i)+exim(i)

只需 1 趟!而且每个状态都可以增量更新。

4.2.4 块级合并(Tiling 友好)

更进一步,两个块的 softmax 状态可以结合律式合并

设块 A(mA,A,O~A)(即 mA=maxjAxjA=jAexjmAO~A=jAexjmAvj);块 B(mB,B,O~B)。则合并后的 AB 状态:

m=max(mA,mB)=AemAm+BemBmO~=O~AemAm+O~BemBm

最终输出:O=O~/

这构成一个 monoid(满足结合律),可以任意分块、任意并行规约。

实现

python
def merge_softmax_block(m_a, l_a, O_a, m_b, l_b, O_b):
    m = torch.maximum(m_a, m_b)
    e_a = (m_a - m).exp()
    e_b = (m_b - m).exp()
    l = e_a * l_a + e_b * l_b
    O = e_a.unsqueeze(-1) * O_a + e_b.unsqueeze(-1) * O_b
    return m, l, O

注意:Oa 这里是 O~a(未除以 a),最后才除一次。

这就是 FlashAttention 的算法基础。


4.3 FlashAttention v1

4.3.1 核心思想

Dao et al. (2022) "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness":

  1. Tiling:把 Q,K,V 按行/列分块,每次只在 SRAM 里处理一小块
  2. Online softmax:边算边维护 (m,,O~)
  3. 不写中间矩阵 S,P 到 HBM:只在 SRAM 中临时计算

4.3.2 算法(前向)

输入: Q, K, V ∈ R^{N×d}
输出: O ∈ R^{N×d}

设块大小 B_r, B_c 满足 B_r·d + B_c·d + B_r·B_c ≤ M (SRAM 大小)
T_r = ⌈N / B_r⌉, T_c = ⌈N / B_c⌉

将 Q 切成 T_r 个行块 {Q_1, ..., Q_{T_r}},每个 Q_i ∈ R^{B_r × d}
将 K, V 切成 T_c 个列块

外循环 i = 1..T_r:
    加载 Q_i 到 SRAM
    初始化 O_i ← 0 ∈ R^{B_r × d}
    初始化 ℓ_i ← 0 ∈ R^{B_r}
    初始化 m_i ← -∞ ∈ R^{B_r}

    内循环 j = 1..T_c:
        加载 K_j, V_j 到 SRAM
        # 计算分块注意力分数
        S_ij = Q_i · K_j^T / √d           # ∈ R^{B_r × B_c}, 仅在 SRAM
        # 在线 softmax 更新
        m_ij = rowmax(S_ij)                # ∈ R^{B_r}
        m_new = max(m_i, m_ij)
        P_ij = exp(S_ij - m_new[:, None])  # ∈ R^{B_r × B_c}
        l_new = exp(m_i - m_new) * ℓ_i + rowsum(P_ij)
        O_i = exp(m_i - m_new)[:, None] * O_i + P_ij · V_j
        m_i = m_new
        ℓ_i = l_new

    # 循环结束后归一化
    O_i = O_i / ℓ_i[:, None]
    写回 O_i 到 HBM
    存 L_i = m_i + log(ℓ_i)(反向用)

注意:对于因果掩码,仅当 jBciBr+Br 时进入内循环(半三角)。

4.3.3 IO 复杂度证明

定理 (Dao 2022, Theorem 2):设 SRAM 大小 MdM。FlashAttention 的 HBM 访问量为

O(N2d2M)

而标准注意力的 HBM 访问量为 Θ(Nd+N2)

证明思路

  • 块大小约束:Brd+Bcd+BrBcM。最优选 Bc=Θ(M/d)Br=min(Θ(M/d),d)
  • 内循环每次加载 Kj,Vj2Bcd 元素,且外循环每次都全量遍历 K,V(共 Tc 次内循环)
  • 每个外循环迭代加载 K,V2Nd 元素
  • 外循环 Tr 次:总 K,V 加载量 Θ(TrNd)=Θ(NBrNd)
  • Br=Θ(M/d):总加载 Θ(N2d2M)
  • Q 加载量 Θ(Nd),输出 O 写入 Θ(Nd),相比之下可忽略

对比

算法HBM 访问A100 (M=192KB) 比例
标准Θ(N2)1
FlashAttentionΘ(N2d2/M)d2/M

d=64d2=4096M=192×1024=196608,比例 ≈ 0.02 → 访存少 50 倍

实测在 A100 上 FlashAttention 加速 GPT-2 (N=1024, d=64) 约 3 倍

4.3.4 反向传播

朴素反向需要 S,P 矩阵。FlashAttention 不存它们,而是:

  • 只存 OL=m+log(每行一个标量,N 个)
  • 反向时重新计算 Pij=exp(SijLi)
  • 这是一种 selective recomputation:用 O(N) 额外显存换 O(N2) 节省

反向 IO 复杂度同 O(N2d2/M)。计算量增加约 25%(因为多算一次 forward),但因为它仍是 memory-bound,实际几乎无墙时间损失。

4.3.5 显存

算法中间显存
标准O(N2)S,P 矩阵)
FlashAttentionO(N)(仅 L,m, 标量)

这让 32K-128K 上下文成为可能。

4.3.6 实测性能

Dao 2022 报告:

  • GPT-2 small (N=1024, d=64): forward+backward 加速 3.5x
  • BERT-large (N=512, d=64): 端到端训练加速 15%
  • 长序列 (N=16K) 内存节省 10-20x

但 FlashAttention v1 的 GPU 利用率仅 25-40%,仍未榨干硬件。


4.4 FlashAttention v2

4.4.1 v1 的瓶颈

Dao 2023 "FlashAttention-2" 分析 v1 在 H100 / A100 上跑得不够快的原因:

  1. 非矩阵乘 FLOPs 占比高:每内循环都要 rescale O,引入大量 element-wise 操作
  2. 外循环并行不足:仅按 batch × heads 并行,对小 batch + 长序列利用率差
  3. Warp 间通信开销:v1 把 K,V 分给不同 warp,每步要 warp 间 reduce

4.4.2 改进 1:减少 rescale

v1 每个内循环都 rescale OiOi=emimnew)。v2 改为只在外循环结束时除一次 i

内循环中只维护未归一化的 \tilde{O}_i 和未除的 \ell_i:
    \tilde{O}_i ← \tilde{O}_i · e^{m_i - m_new} + P_ij V_j      # 仍需 rescale
    \ell_i     ← \ell_i · e^{m_i - m_new} + rowsum(P_ij)
循环结束后:
    O_i = \tilde{O}_i / \ell_i

实际上 v1 也是这种结构,但 v2 做了进一步的 IO 优化:把每个 block 的 rescale 因子算法重整理,让它能 fuse 到 GEMM 中。

4.4.3 改进 2:交换循环顺序

v1:外 K 循环、内 Q 循环——不同 Q 块的输出 Oi 在外循环中累加,需要跨外循环迭代保存中间状态

v2:外 Q 循环、内 K 循环——每个 Q 块在一个外循环迭代内完整算完输出,并行单元独立

for i = 1..T_r:                    # Q 块(外)
    for j = 1..T_c:                # K 块(内,仅看 j ≤ i 的因果范围)
        ...
    O_i = \tilde{O}_i / \ell_i
    write O_i

外循环可在 Tr 个 thread block 中并行——新增了序列维并行

4.4.4 改进 3:序列并行

v1 的并行度 = batch_size × n_heads。当 batch=1(推理)或长序列训练时,GPU 可能跑不满(一个 H100 有 132 个 SM,需要至少 132 个并发 thread block)。

v2 的并行度 = batch_size × n_heads × Tr。即使 batch=1,长序列下也有 Tr 个 Q 块独立并行,能填满所有 SM。

4.4.5 改进 4:Warp 协作

v1 把 Kj 分给不同 warp,每个 warp 算 Sij 的一部分,warp 间需要 reduce m,

v2 把 Qi 分给不同 warp,每个 warp 处理 Qi 的一部分行——完全独立,无需 warp 间通信。Kj,Vj 在所有 warp 间共享(broadcast)。

4.4.6 综合效果

A100 (BF16, N=2048,d=64):

实现TFLOPS利用率
PyTorch SDPA7825%
FlashAttention v112440%
FlashAttention v222673%
理论峰值312100%

v2 比 v1 快约 2x,比 PyTorch 标准实现快近 3x。在 LLaMA-2 7B 训练上端到端加速 25-30%。


4.5 FlashAttention v3 (Hopper)

4.5.1 H100 新特性

Hopper (H100) 引入了一系列异步特性:

  1. WGMMA (Warpgroup Matrix Multiply Accumulate):异步张量核指令,4 个 warp 协作发射,FLOPs 翻倍
  2. TMA (Tensor Memory Accelerator):硬件 DMA 引擎,专门搬运 tile,无需 warp 参与
  3. FP8 张量核:E4M3/E5M2 格式,1979 TFLOPS(vs BF16 989 TFLOPS)
  4. 更大 SRAM:每 SM 228 KB

FlashAttention v2 在 H100 上仅达 35% 峰值(BF16 ~700 TFLOPS)——因为它没有用到这些异步能力。

4.5.2 改进 1:Warp 专门化 (Producer-Consumer)

Shah et al. (2024) "FlashAttention-3" 把 warp 分两组:

  • Producer warp:用 TMA 把下一块 Kj,Vj 从 HBM 加载到 SRAM(异步,与计算重叠)
  • Consumer warp:用 WGMMA 算当前块的 SijPijVj

像 CPU 流水线一样,producer 提前预取,consumer 不等待。

4.5.3 改进 2:异步流水

更细致地,consumer 内部也分阶段:

stage 0: 算 S_{ij} = Q_i K_j^T (WGMMA1, 异步发射)
         同时: producer 加载 K_{j+1}, V_{j+1}
stage 1: 等 WGMMA1 完成,算 softmax(S_{ij})
         同时: WGMMA1 已完成,发射 WGMMA2 算 P_{ij} V_j
stage 2: 等 WGMMA2 完成,更新 O_i

GEMM 与 softmax 重叠,softmax 与下一块 GEMM 重叠。

4.5.4 改进 3:FP8 支持

E4M3 用于前向(精度优先),E5M2 用于反向(动态范围优先)。

挑战:

  • FP8 动态范围窄,注意力分数容易溢出
  • softmax 的指数运算对精度敏感

解决:

  • 每块缩放:每个 Kj 块在加载时做一次 scale,记录 scale factor
  • 混合精度累加:FP8 GEMM 的累加器仍是 FP32
  • Hadamard 变换:在 Q、K 上施加随机 Hadamard 矩阵,把数值分布拉平,降低量化误差
python
# 概念示意(非真实代码)
H = hadamard(d)  # Hadamard 矩阵, 元素 ±1/√d
Q_h = Q @ H      # 旋转到 "incoherent" 基
K_h = K @ H      # 同上
# Q K^T = (Q H) (K H)^T = Q K^T (H H^T = I),结果不变
# 但 Q_h, K_h 的元素分布更"均匀",FP8 量化误差小

4.5.5 实测性能

H100 (SXM5, BF16/FP8):

实现BF16 TFLOPSFP8 TFLOPS
FlashAttention v2348-
cuDNN460-
FlashAttention v37401200
理论峰值9891979

v3 BF16 利用率达 75%,FP8 利用率 60%。比 v2 快 2x,比 cuDNN 快 1.6x

DeepSeek-V3 用 FlashAttention v3 + FP8 是首个公开的 FP8 大规模预训练,在 H800 上达成 ~50% MFU(密集训练里很高)。


4.6 FlashAttention 系列对比

版本GPUBF16 利用率关键创新何时发布
StandardA100~25%baseline-
FlashAttention v1A10025-40%tiling + online softmax2022.05
FlashAttention v2A10050-73%交换循环、序列并行、warp 独立2023.07
FlashAttention v3H10075% (BF16), 60% (FP8)warp 专门化、TMA、FP82024.07

显存:所有版本都是 O(N)(vs 标准 O(N2)),这是结构性的优势。


4.7 实战:使用 FlashAttention

4.7.1 PyTorch 内置 SDPA

PyTorch 2.0+ 的 scaled_dot_product_attention 自动选择最优后端:

python
import torch.nn.functional as F

# 自动选择 FlashAttention / Memory-Efficient / Math
out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,
)

可手动控制后端:

python
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

PyTorch 2.5+ 默认在 H100 上用 FlashAttention v2,未来会切到 v3。

4.7.2 flash-attn 库(独立)

Dao 团队的官方 CUDA 实现:

bash
pip install flash-attn --no-build-isolation
python
from flash_attn import flash_attn_func

# q, k, v: [batch, seq_len, n_heads, d_head]
out = flash_attn_func(q, k, v, causal=True, softmax_scale=1.0/math.sqrt(d_head))

支持:

  • flash_attn_varlen_func:变长序列(无 padding)
  • flash_attn_with_kvcache:推理时的 KV-Cache 接口
  • flash_attn_qkvpacked_func:QKV 打包(减少访存)

4.7.3 Triton 实现

OpenAI Triton 提供 FlashAttention 的 Triton 实现,便于研究和定制:

python
# triton/python/tutorials/06-fused-attention.py
import triton
import triton.language as tl

@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                    start_m, qk_scale, BLOCK_M, BLOCK_N, ...):
    # 内循环:遍历 K, V 块
    for start_n in range(0, ...):
        k = tl.load(K_block_ptr)
        qk = tl.dot(q, k) * qk_scale
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        p = tl.exp(qk - m_ij[:, None])
        alpha = tl.exp(m_i - m_ij)
        l_i = alpha * l_i + tl.sum(p, 1)
        v = tl.load(V_block_ptr)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m_i = m_ij
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    return acc, l_i, m_i

完整 Triton 实现约 200 行,性能接近 CUDA 版(80-90%)。

4.7.4 何时用何种实现

场景推荐
标准训练 / 推理PyTorch SDPA(自动选择)
极致性能 + 标准 attnflash-attn 库
变长序列(无 padding)flash-attn varlen_func
自定义 attn 变体(ALiBi, sliding window)Triton 实现,自己改 kernel
H100 + FP8FlashAttention v3(cuda 12+)

4.8 局限与展望

4.8.1 不支持的特性

FlashAttention 主要针对标准 scaled dot-product。以下场景需要特殊处理:

  • ALiBi:相对位置偏置加在 S 上,FlashAttention v2.1+ 支持
  • Sliding Window:Mistral 用,FlashAttention 支持 window_size 参数
  • Mixture of Attention:每个 head 用不同 attention 模式,需要分别 kernel
  • Cross-attention:encoder-decoder 中,K/V 来自不同序列,FlashAttention 也支持

4.8.2 极长上下文

序列长度 N 时,即使 FlashAttention,计算仍是 O(N2)。这激发了:

  • Sparse Attention (BigBird, Longformer):稀疏化 attention 矩阵
  • Linear Attention (Performer, Linear Transformer):核技巧近似 softmax,复杂度 O(N)
  • State Space Models (Mamba, S4):完全替换 attention,复杂度 O(N)
  • Ring Attention (Liu 2023):跨 GPU 切分序列,单卡看 N/P 长度,全局通信打通

4.8.3 推理优化的特殊性

Prefill 阶段适合 FlashAttention(长序列、并行)。

Decode 阶段每次只生成 1 个 token,序列维度退化为 1,并行度极低。这时需要特殊 kernel:

  • FlashDecoding:把 K-cache 切成块,多个 SM 协作算单个 query
  • PagedAttention (vLLM):页式 KV-Cache 管理,提高并发
  • Group-Query Attention 友好实现:减少 KV 访存

4.9 IO 复杂度的进一步思考

4.9.1 下界

Dao 2022 证明:任何精确注意力算法的 HBM 访问下界是 Ω(N2d2/M)(在 dM1/2 时)。

FlashAttention 达到这个下界(up to constant),是 IO 最优的。

4.9.2 与 GEMM 的类比

GEMM 的最优 IO 是 Θ(MNK/S)(其中 S 是 cache 大小)。注意力的 IO 下界与之同源——都是"分块矩阵乘"的本质。

理解这个,对设计任何 GPU kernel 都有启发:算术强度低的算子,要靠 tiling + 异步预取榨干带宽

4.9.3 未来方向

  • 跨层融合:attention + FFN 共享 SRAM tile,减少层间 HBM 倒腾
  • Dynamic shape:变长序列动态 tile size,减少 padding 浪费
  • 混合精度优化:FP4 / INT8 等低精度的 attention(仍在研究)
  • ASIC 特化:Cerebras WSE、Groq 等专门硬件,把 attention 视为一等公民

4.10 本章小结

  1. 注意力是 memory-bound:算术强度仅 d/232,远低于 A100 平衡点 208。瓶颈是 N×N 中间矩阵的 HBM 读写。
  2. Online softmax 提供了边算边归一的能力,且块级合并满足结合律——这是 tiling 可行的数学基础。
  3. FlashAttention v1:tiling + online softmax,HBM 访问 Θ(N2d2/M),比标准减少约 M/d2 倍;显存从 O(N2) 降至 O(N)
  4. FlashAttention v2:交换循环(外 Q 内 K)、序列维并行、warp 独立,A100 利用率从 40% 提升到 73%。
  5. FlashAttention v3:Hopper 异步特性(WGMMA、TMA、FP8、warp 专门化),H100 BF16 75%、FP8 60%。
  6. 生产实践:PyTorch 2.5+ SDPA、flash-attn 库、Triton 实现,按需选择。

下一章我们讨论稀疏激活的另一条路线——MoE 架构。


4.11 思考题

  1. IO 下界推导:假设 Q,K,VRN×d,SRAM 大小 MdM。请证明任何精确注意力算法的 HBM 访问量至少为 Ω(N2d2/M)。提示:考虑信息论下界,每个 Sij 需要至少访问 QiKj 各一次。

  2. 块大小设计:A100 的 SRAM 大小 M=192 KB,FP16 下每元素 2 字节。d=64,FlashAttention v2 应取 Br,Bc 多大?分别考虑约束 Brd+Bcd+BrBcM/2(除以 2 留余量给寄存器和栈)。给出至少 2 组合理配置并比较。

  3. FP8 误差分析:FP8 E4M3 的最大表示数 ±448,最小正规数 ±26。设 Q,KN(0,1)d=128,则 Sij=(QKT)ij/d 的分布是什么?该分布在 FP8 下如何缩放才能避免溢出?请定量说明 per-block scaling 的必要性。

基于 MIT 协议发布