Flash Attention

上一篇:[[KV Cache]] 下一篇:[[RoPE]]

LLaMA 3 相对于传统 Transformer 架构在你提到的两点上确实有显著优化:

  1. 使用 Flash Attention 来加速注意力机制的计算;

  2. 使用 RoPE(Rotary Positional Embedding)来替代传统的绝对位置编码(Absolute Positional Encoding)。

接下来我们分别详细从数学原理和数值例子上解释这两点差异:


一、Flash Attention vs 传统 Attention

传统 Attention 计算方式

原始 Transformer 中的注意力计算是这样的:

给定输入矩阵 $X \in \mathbb{R}^{n \times d}$,其中 n 是序列长度,d 是嵌入维度:

这个计算中涉及了整个 $QK^T$ 矩阵的显式构建,内存消耗为 $O(n^2)$,在长序列上尤其低效。


Flash Attention 的改进

Flash Attention 是一种内存优化和高效并行的计算方式,它的主要思路是:

  1. 避免显式构建整个 $QK^T$ 矩阵

  2. 在 block-wise 方式下计算 softmax 的同时,保留数值稳定性

  3. 融合 kernel 操作,避免重复读取数据

其核心算法是使用类似以下的流程(简化版):

公式层面近似如下

对于每个 $block_i$:

$\text{output}_i = \sum_j \text{softmax}\left(\frac{Q_i K_j^T}{\sqrt{d_k}}\right) V_j$

但计算中不显式构建完整 $QK^T$,而是在计算 softmax 的过程中在线归一化(保持数值稳定)和逐步聚合。


数值对比举例

假设:

那么传统 attention 需要显式存储 $QK^T \in \mathbb{R}^{4096 \times 4096}$,光这个矩阵就占:

$4096^2 \times 4 \text{ bytes} \approx 64 \text{ MB(float32)}$

而 Flash Attention 的内存复杂度约为 线性 $O(n)$,大大节省内存并加速。


一、Flash Attention vs 传统 Attention(序列长度为 4)

准备数据

假设我们有 4 个 token,其输入嵌入维度为 2:

$$ X = \begin{bmatrix} 1 & 0 \ 0 & 1 \ 1 & 1 \ 0.5 & 0.5 \ \end{bmatrix} \in \mathbb{R}^{4 \times 2} $$

我们用单位矩阵作为线性变换(简化):

$$ Q = K = V = X $$

传统 Attention

计算 $QK^T$:

$$ QK^T = X X^T = \begin{bmatrix} 1 & 0 \ 0 & 1 \ 1 & 1 \ 0.5 & 0.5 \ \end{bmatrix} \begin{bmatrix} 1 & 0 & 1 & 0.5 \ 0 & 1 & 1 & 0.5 \end{bmatrix}

\begin{bmatrix} 1 & 0 & 1 & 0.5 \ 0 & 1 & 1 & 0.5 \ 1 & 1 & 2 & 1.0 \ 0.5 & 0.5 & 1.0 & 0.5 \end{bmatrix} $$

接着除以 $\sqrt{2} \approx 1.414$,再对每一行做 softmax 得注意力权重矩阵 $A$,再乘 $V$。

这个过程中完整构造了 $4 \times 4$ 的矩阵,代价 $O(n^2)$。


Flash Attention 简要演示

Flash Attention 不会显式构造整个 $QK^T$,而是按块读取,比如读取每次两行(Token 1-2)去对比 Token 3-4:

$$ \text{Attention}i = \frac{\sum{j=0}^{L-1} e^{\frac{Q_i K_j^T}{\sqrt{d_k}}} V_j}{\sum_{j=0}^{L-1} e^{\frac{Q_i K_j^T}{\sqrt{d_k}}}} $$

结果是一样的,但中间计算不保存整个 $QK^T$ 而是边算边聚合,提升了效率并节省显存,尤其在 $n = 4096+$ 时体现明显。 我们来做一个完整、详细的 Flash Attention 计算演示,对比传统 Attention,在小规模上(4 tokens,dim=2)手动走一遍 Flash Attention 的数值计算过程,包含:


✅ 设定

我们取一个简化的输入,序列长度 $n = 4$,维度 $d = 2$

$$ Q = K = V = \begin{bmatrix} 1.0 & 0.0 \ 0.0 & 1.0 \ 1.0 & 1.0 \ 0.5 & 0.5 \ \end{bmatrix} $$

我们不使用线性投影权重,简化为直接用输入做 Q/K/V(常见于教学)。

缩放因子为:

$$ \frac{1}{\sqrt{d}} = \frac{1}{\sqrt{2}} \approx 0.707 $$


🧠 传统 Attention 计算过程

步骤 1:计算 $QK^T$

$$ QK^T = \begin{bmatrix} 1 & 0 \ 0 & 1 \ 1 & 1 \ 0.5 & 0.5 \ \end{bmatrix} \begin{bmatrix} 1 & 0 & 1 & 0.5 \ 0 & 1 & 1 & 0.5 \ \end{bmatrix}

\begin{bmatrix} 1.0 & 0.0 & 1.0 & 0.5 \ 0.0 & 1.0 & 1.0 & 0.5 \ 1.0 & 1.0 & 2.0 & 1.0 \ 0.5 & 0.5 & 1.0 & 0.5 \ \end{bmatrix} $$

步骤 2:缩放 dot-product

$$ QK^T_{\text{scaled}} = QK^T \times 0.707 \Rightarrow \begin{bmatrix} 0.707 & 0.000 & 0.707 & 0.354 \ 0.000 & 0.707 & 0.707 & 0.354 \ 0.707 & 0.707 & 1.414 & 0.707 \ 0.354 & 0.354 & 0.707 & 0.354 \ \end{bmatrix} $$

步骤 3:对每一行做 softmax

例如第一行: $$ \text{softmax}([0.707, 0, 0.707, 0.354]) = \frac{e^{x_i}}{\sum e^{x}} \Rightarrow \ $$ $$ e^{0.707} \approx 2.028,\ e^{0} = 1,\ e^{0.354} \approx 1.425 \Rightarrow \ $$ $$ \text{Denominator} = 2.028 + 1 + 2.028 + 1.425 \approx 6.481 \Rightarrow \ $$ $$ \text{Softmax row 1} = [0.313, 0.154, 0.313, 0.220] $$

对所有行做 softmax 后得到 $A \in \mathbb{R}^{4 \times 4}$。

步骤 4:最终输出

$$ \text{Output} = A \cdot V $$


下面我会用两个清晰的表格来展示 Flash Attention 的逐步计算过程。每个表格表示一个 token 的 attention 输出过程,逐步处理所有 K/V 向量。

我们将演示:


🔢 设定

假设 $d = 2$,缩放因子为 $1/\sqrt{2} \approx 0.707$

Q / K / V 相同:

Token向量
Q₀/K₀/V₀$[1.0,\ 0.0]$
Q₁/K₁/V₁$[0.0,\ 1.0]$
Q₂/K₂/V₂$[1.0,\ 1.0]$
Q₃/K₃/V₃$[0.5,\ 0.5]$

📊 表格 1:Token 0 的 Flash Attention 计算(Q₀ = [1, 0])

Step (K/V index)Dot = Q·KScaled = Dot×0.707Softmax 分子 $e^{score}$累计分母累计加权值(分子)
K₀/V₀1.00.7072.0282.028$2.028 \cdot [1.0, 0.0] = [2.028, 0.0]$
K₁/V₁0.00.0001.0003.028$[2.028, 0.0] + 1.0 \cdot [0,1] = [2.028, 1.0]$
K₂/V₂1.00.7072.0285.056$[2.028, 1.0] + 2.028 \cdot [1,1] = [4.056, 3.028]$
K₃/V₃0.50.3541.4256.481$[4.056, 3.028] + 1.425 \cdot [0.5, 0.5] = [4.7685, 3.7405]$

🎯 最终输出(Token 0):

$$ \frac{[4.7685,\ 3.7405]}{6.481} \approx [0.736,\ 0.577] $$


📊 表格 2:Token 1 的 Flash Attention 计算(Q₁ = [0, 1])

Step (K/V index)Dot = Q·KScaled = Dot×0.707Softmax 分子 $e^{score}$累计分母累计加权值(分子)
K₀/V₀0.00.0001.0001.000$1.0 \cdot [1.0, 0.0] = [1.0, 0.0]$
K₁/V₁1.00.7072.0283.028$[1.0, 0.0] + 2.028 \cdot [0,1] = [1.0, 2.028]$
K₂/V₂1.00.7072.0285.056$[1.0, 2.028] + 2.028 \cdot [1,1] = [3.028, 4.056]$
K₃/V₃0.50.3541.4256.481$[3.028, 4.056] + 1.425 \cdot [0.5, 0.5] = [3.7405, 4.7685]$

🎯 最终输出(Token 1):

$$ \frac{[3.7405,\ 4.7685]}{6.481} \approx [0.577,\ 0.736] $$


✅ 小结

TokenOutput 向量
0$[0.736,\ 0.577]$
1$[0.577,\ 0.736]$

你可以看到:

$$ \text{Attention}i = \frac{\sum{j=0}^{L-1} e^{\frac{Q_i K_j^T}{\sqrt{d_k}}} V_j}{\sum_{j=0}^{L-1} e^{\frac{Q_i K_j^T}{\sqrt{d_k}}}} $$

项目传统 AttentionFlash Attention
计算方式显式构建 QK^T按块计算、融合 softmax
内存复杂度$O(n^2)$$O(n)$
示例效果需要完整矩阵无需构建完整矩阵
上一篇:[[KV Cache]] 下一篇:[[RoPE]]