上一篇:[[Decoder]] 下一篇:[[Flash Attention]]
KV Cache 指的是 解码器(Decoder) 的 Key (K) 和 Value (V),在 推理(Inference) 时缓存,以提高计算效率。
在 Transformer 结构中:
Key/Value 是固定的,因为输入序列在解码过程中不会改变。Key/Value 需要随着时间步(timestep)增长,每次解码一个新 token 时,都会更新。在 自回归(Autoregressive)推理 时:
普通 Transformer 解码:
Attention(Q, K, V)K/V 需要重新计算 整个已解码序列,计算量随着 t 增长变大。使用 KV Cache 优化:
K 和 V,只计算 当前新生成的 token 的 K/V。Q,然后直接用 K Cache 和 V Cache 计算注意力。K/V 只涉及 新生成的 token,减少冗余计算,提高推理速度。KV Cache 主要用于 解码器的两种注意力:
Masked Self-Attention(解码器自身)
K 和 V 来自 解码器自身QCross-Attention(编码器-解码器)
K 和 V 来自 编码器K/V 不会变化,可以直接缓存 整个 encoder_output,不需要每次重新计算。每次解码一个新 token,所有 K/V 都需要重新计算:
$$ Step 1: Attention(Q_1, [K_1], [V_1]) $$ $$ Step 2: Attention(Q_2, [K_1, K_2], [V_1, V_2]) $$ $$ Step 3: Attention(Q_3, [K_1, K_2, K_3], [V_1, V_2, V_3]) $$
计算量不断增加。
只计算新 token 的 K/V,其余的直接复用:
$$ Step 1: Attention(Q_1, [K_1], [V_1]) -> Cache [K_1, V_1] $$ $$ Step 2: Attention(Q_2, [K_1, K_2], [V_1, V_2]) -> Cache [K_1, K_2, V_1, V_2] $$ $$ Step 3: Attention(Q_3, [K_1, K_2, K_3], [V_1, V_2, V_3]) -> Cache [K_1, K_2, K_3, V_1, V_2, V_3] $$
减少计算量,提高推理速度。
KV Cache 只用于解码器(Decoder) 的 K 和 VK/VK/V 由编码器输出,直接缓存,不需要更新在推理阶段(即逐个解码 token 时),每个 DecoderBlock 的前面
seq_len个 token 的输出是不会改变的,新增的 token 只会在每一层引入新的计算,不会反过来修改已存在的输出。
但这个「不会改变」只在以下前提下成立:
每次只处理一个新 token(即自回归、单步推理),之前的输出缓存下来;
注意力机制是 Masked 的,新 token 的 attention 只能看到它之前的;
前层输出不改变,后层就不会重新计算前面位置的输出;
残差连接不影响前面位置的表示,因为它们是逐 token 的操作;
DecoderBlock 为例解释 为什么前面 token 的输出不会改变我们基于如下结构(和你之前的代码思路一致):
class DecoderBlock(nn.Module):
def __init__(self, embed_size, heads):
self.self_attn = MultiHeadAttention(embed_size, heads)
self.cross_attn = MultiHeadAttention(embed_size, heads)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, 4 * embed_size),
nn.ReLU(),
nn.Linear(4 * embed_size, embed_size),
)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.norm3 = nn.LayerNorm(embed_size)
def forward(self, x, encoder_output, self_kv_cache=None):
# 1. Masked Self-Attention
_x = self.self_attn(x, x, x, cache=self_kv_cache, mask=True)
x = self.norm1(x + _x)
# 2. Cross-Attention with encoder
_x = self.cross_attn(encoder_output, encoder_output, x)
x = self.norm2(x + _x)
# 3. Feed-forward
_x = self.feed_forward(x)
x = self.norm3(x + _x)
return x
假设你在位置 t 新增一个 token,那么:
之前位置 0 ~ t-1 的 query/key/value 都在 cache 里,不变;
当前位置 t 会计算出新的 Q_t,并将 K_t/V_t 加入缓存;
由于 attention mask 是 只允许看到之前位置,所以:
新的 token Q_t 可能 attend 到 [K_0, ..., K_{t-1}]
但旧的 token Q_i (i < t) 并不会重新计算或 attend 到 K_t(新加的 key)
✅ 旧 token 不受新 token 影响。
输入为 encoder_output(固定) 和 x(decoder 的当前输出)
对位置 t 来说,它会用 Q_t 去 attend 编码器的 K_enc, V_enc;
对位置 0 ~ t-1,它们的 Q 也在缓存中,没必要重新计算;
✅ decoder 中旧 token 的 cross-attention 输出不会因新 token 改变。
每个 token 都是独立输入一个小 MLP 的,即:
x_i_out = FFN(x_i) # 对每个 i 单独做
因此位置 0 ~ t-1 的 token 的 FFN 输出不会因为加入了位置 t 而改变。
即便是在训练时整个序列一起计算,它其实也是并行执行多个 FFN,每个 token 的计算是独立的。
✅ MLP 层(全连接)不会影响其他位置。
x + f(x) 是 element-wise 的;
LayerNorm 是对每个 token 的特征向量归一化;
不会混合位置之间的信息;
只改变当前 token 自己的表示分布;
✅ 不会跨 token 影响其他位置。
在 Transformer 解码器中,每个位置的输出都是局部计算的,所以只要你不重新运行前面的 token,它们的值就不会被改变,新 token 的加入只会增加一个新位置的计算,不会干扰之前的结果。
假设我们在解码器中已有输出:
x = [x₀, x₁] => shape: [2, embed_size]
现在我们生成了一个新 token,成为:
x = [x₀, x₁, x₂] => shape: [3, embed_size]
在推理时我们只会计算 x₂ 的 self-attention(Q₂),不会重新计算 x₀/x₁
旧的 [x₀, x₁] 是已经 cache 下来的,输出固定,不变;
新的输出 x₂ 会进入下一层 decoder block,但 [x₀, x₁] 的输出保持原样传下去。
我们先来看可视化示意图,然后我会从多个角度分析:你说法在哪些情况下可能不成立。
┌────────────── Decoder Layer 1 ──────────────┐
│ │
│ x₀ ──────► unchanged │
│ x₁ ──────► unchanged │
│ x₂ ──────► only x₂ gets processed │
│ │
└─────────────────────────────────────────────┘
│
▼
┌────────────── Decoder Layer 2 ──────────────┐
│ │
│ x₀ ──────► cache reused │
│ x₁ ──────► cache reused │
│ x₂ ──────► new output │
│ │
└─────────────────────────────────────────────┘
│
▼
...
| Token | Masked Self-Attn | Cross-Attn (Encoder) | FFN | 输出是否会变化 |
|---|---|---|---|---|
| x₀ | 不重新计算(缓存) | 不重新计算 | 不重新计算 | ❌ 不变 |
| x₁ | 不重新计算(缓存) | 不重新计算 | 不重新计算 | ❌ 不变 |
| x₂ | ✅ 新计算 | ✅ 新计算 | ✅ 新计算 | ✅ 只影响 x₂ |
上一篇:[[Decoder]] 下一篇:[[Flash Attention]]