上一篇:[[KV Cache]] 下一篇:[[RoPE]]
LLaMA 3 相对于传统 Transformer 架构在你提到的两点上确实有显著优化:
使用 Flash Attention 来加速注意力机制的计算;
使用 RoPE(Rotary Positional Embedding)来替代传统的绝对位置编码(Absolute Positional Encoding)。
接下来我们分别详细从数学原理和数值例子上解释这两点差异:
原始 Transformer 中的注意力计算是这样的:
给定输入矩阵 $X \in \mathbb{R}^{n \times d}$,其中 n 是序列长度,d 是嵌入维度:
计算 Query、Key、Value:
$Q = XW^Q,\quad K = XW^K,\quad V = XW^V$
计算 Attention:
$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$
这个计算中涉及了整个 $QK^T$ 矩阵的显式构建,内存消耗为 $O(n^2)$,在长序列上尤其低效。
Flash Attention 是一种内存优化和高效并行的计算方式,它的主要思路是:
避免显式构建整个 $QK^T$ 矩阵
在 block-wise 方式下计算 softmax 的同时,保留数值稳定性
融合 kernel 操作,避免重复读取数据
其核心算法是使用类似以下的流程(简化版):
将序列分为多个块(blocks)
对每个 block:
读取对应的 Q、K、V 子块
累积 softmax 的分子、分母部分(保持数值稳定)
累积最终的 attention 输出
公式层面近似如下:
对于每个 $block_i$:
$\text{output}_i = \sum_j \text{softmax}\left(\frac{Q_i K_j^T}{\sqrt{d_k}}\right) V_j$
但计算中不显式构建完整 $QK^T$,而是在计算 softmax 的过程中在线归一化(保持数值稳定)和逐步聚合。
假设:
输入序列长度 n=4096,
维度 d=128
那么传统 attention 需要显式存储 $QK^T \in \mathbb{R}^{4096 \times 4096}$,光这个矩阵就占:
$4096^2 \times 4 \text{ bytes} \approx 64 \text{ MB(float32)}$
而 Flash Attention 的内存复杂度约为 线性 $O(n)$,大大节省内存并加速。
假设我们有 4 个 token,其输入嵌入维度为 2:
$$ X = \begin{bmatrix} 1 & 0 \ 0 & 1 \ 1 & 1 \ 0.5 & 0.5 \ \end{bmatrix} \in \mathbb{R}^{4 \times 2} $$
我们用单位矩阵作为线性变换(简化):
$$ Q = K = V = X $$
计算 $QK^T$:
\begin{bmatrix} 1 & 0 & 1 & 0.5 \ 0 & 1 & 1 & 0.5 \ 1 & 1 & 2 & 1.0 \ 0.5 & 0.5 & 1.0 & 0.5 \end{bmatrix} $$
接着除以 $\sqrt{2} \approx 1.414$,再对每一行做 softmax 得注意力权重矩阵 $A$,再乘 $V$。
这个过程中完整构造了 $4 \times 4$ 的矩阵,代价 $O(n^2)$。
Flash Attention 不会显式构造整个 $QK^T$,而是按块读取,比如读取每次两行(Token 1-2)去对比 Token 3-4:
$$ \text{Attention}i = \frac{\sum{j=0}^{L-1} e^{\frac{Q_i K_j^T}{\sqrt{d_k}}} V_j}{\sum_{j=0}^{L-1} e^{\frac{Q_i K_j^T}{\sqrt{d_k}}}} $$
结果是一样的,但中间计算不保存整个 $QK^T$ 而是边算边聚合,提升了效率并节省显存,尤其在 $n = 4096+$ 时体现明显。 我们来做一个完整、详细的 Flash Attention 计算演示,对比传统 Attention,在小规模上(4 tokens,dim=2)手动走一遍 Flash Attention 的数值计算过程,包含:
我们取一个简化的输入,序列长度 $n = 4$,维度 $d = 2$:
$$ Q = K = V = \begin{bmatrix} 1.0 & 0.0 \ 0.0 & 1.0 \ 1.0 & 1.0 \ 0.5 & 0.5 \ \end{bmatrix} $$
我们不使用线性投影权重,简化为直接用输入做 Q/K/V(常见于教学)。
缩放因子为:
$$ \frac{1}{\sqrt{d}} = \frac{1}{\sqrt{2}} \approx 0.707 $$
\begin{bmatrix} 1.0 & 0.0 & 1.0 & 0.5 \ 0.0 & 1.0 & 1.0 & 0.5 \ 1.0 & 1.0 & 2.0 & 1.0 \ 0.5 & 0.5 & 1.0 & 0.5 \ \end{bmatrix} $$
$$ QK^T_{\text{scaled}} = QK^T \times 0.707 \Rightarrow \begin{bmatrix} 0.707 & 0.000 & 0.707 & 0.354 \ 0.000 & 0.707 & 0.707 & 0.354 \ 0.707 & 0.707 & 1.414 & 0.707 \ 0.354 & 0.354 & 0.707 & 0.354 \ \end{bmatrix} $$
例如第一行: $$ \text{softmax}([0.707, 0, 0.707, 0.354]) = \frac{e^{x_i}}{\sum e^{x}} \Rightarrow \ $$ $$ e^{0.707} \approx 2.028,\ e^{0} = 1,\ e^{0.354} \approx 1.425 \Rightarrow \ $$ $$ \text{Denominator} = 2.028 + 1 + 2.028 + 1.425 \approx 6.481 \Rightarrow \ $$ $$ \text{Softmax row 1} = [0.313, 0.154, 0.313, 0.220] $$
对所有行做 softmax 后得到 $A \in \mathbb{R}^{4 \times 4}$。
$$ \text{Output} = A \cdot V $$
下面我会用两个清晰的表格来展示 Flash Attention 的逐步计算过程。每个表格表示一个 token 的 attention 输出过程,逐步处理所有 K/V 向量。
我们将演示:
假设 $d = 2$,缩放因子为 $1/\sqrt{2} \approx 0.707$
Q / K / V 相同:
| Token | 向量 |
|---|---|
| Q₀/K₀/V₀ | $[1.0,\ 0.0]$ |
| Q₁/K₁/V₁ | $[0.0,\ 1.0]$ |
| Q₂/K₂/V₂ | $[1.0,\ 1.0]$ |
| Q₃/K₃/V₃ | $[0.5,\ 0.5]$ |
| Step (K/V index) | Dot = Q·K | Scaled = Dot×0.707 | Softmax 分子 $e^{score}$ | 累计分母 | 累计加权值(分子) |
|---|---|---|---|---|---|
| K₀/V₀ | 1.0 | 0.707 | 2.028 | 2.028 | $2.028 \cdot [1.0, 0.0] = [2.028, 0.0]$ |
| K₁/V₁ | 0.0 | 0.000 | 1.000 | 3.028 | $[2.028, 0.0] + 1.0 \cdot [0,1] = [2.028, 1.0]$ |
| K₂/V₂ | 1.0 | 0.707 | 2.028 | 5.056 | $[2.028, 1.0] + 2.028 \cdot [1,1] = [4.056, 3.028]$ |
| K₃/V₃ | 0.5 | 0.354 | 1.425 | 6.481 | $[4.056, 3.028] + 1.425 \cdot [0.5, 0.5] = [4.7685, 3.7405]$ |
$$ \frac{[4.7685,\ 3.7405]}{6.481} \approx [0.736,\ 0.577] $$
| Step (K/V index) | Dot = Q·K | Scaled = Dot×0.707 | Softmax 分子 $e^{score}$ | 累计分母 | 累计加权值(分子) |
|---|---|---|---|---|---|
| K₀/V₀ | 0.0 | 0.000 | 1.000 | 1.000 | $1.0 \cdot [1.0, 0.0] = [1.0, 0.0]$ |
| K₁/V₁ | 1.0 | 0.707 | 2.028 | 3.028 | $[1.0, 0.0] + 2.028 \cdot [0,1] = [1.0, 2.028]$ |
| K₂/V₂ | 1.0 | 0.707 | 2.028 | 5.056 | $[1.0, 2.028] + 2.028 \cdot [1,1] = [3.028, 4.056]$ |
| K₃/V₃ | 0.5 | 0.354 | 1.425 | 6.481 | $[3.028, 4.056] + 1.425 \cdot [0.5, 0.5] = [3.7405, 4.7685]$ |
$$ \frac{[3.7405,\ 4.7685]}{6.481} \approx [0.577,\ 0.736] $$
| Token | Output 向量 |
|---|---|
| 0 | $[0.736,\ 0.577]$ |
| 1 | $[0.577,\ 0.736]$ |
你可以看到:
$$ \text{Attention}i = \frac{\sum{j=0}^{L-1} e^{\frac{Q_i K_j^T}{\sqrt{d_k}}} V_j}{\sum_{j=0}^{L-1} e^{\frac{Q_i K_j^T}{\sqrt{d_k}}}} $$
| 项目 | 传统 Attention | Flash Attention |
|---|---|---|
| 计算方式 | 显式构建 QK^T | 按块计算、融合 softmax |
| 内存复杂度 | $O(n^2)$ | $O(n)$ |
| 示例效果 | 需要完整矩阵 | 无需构建完整矩阵 |
| 上一篇:[[KV Cache]] 下一篇:[[RoPE]] |