上一篇:[[Decoder]] 下一篇:[[Flash Attention]]
KV Cache 指的是 解码器(Decoder) 的 Key (K) 和 Value (V),在 推理(Inference) 时缓存,以提高计算效率。
在 Transformer 结构中:
- 编码器(Encoder) 的
Key/Value是固定的,因为输入序列在解码过程中不会改变。 - 解码器(Decoder) 的
Key/Value需要随着时间步(timestep)增长,每次解码一个新 token 时,都会更新。
为什么需要 KV Cache?
在 自回归(Autoregressive)推理 时:
普通 Transformer 解码:
- 计算 $Q = W_q * x_t,K = W_k * x_t,V = W_v * x_t$`
- 计算注意力:
Attention(Q, K, V) - 由于
K/V需要重新计算 整个已解码序列,计算量随着t增长变大。
使用 KV Cache 优化:
- 缓存之前计算过的
K和V,只计算 当前新生成的 token 的K/V。 - 计算新 token 的
Q,然后直接用K Cache和V Cache计算注意力。 - 这样每一步计算的
K/V只涉及 新生成的 token,减少冗余计算,提高推理速度。
- 缓存之前计算过的
KV Cache 的作用
KV Cache 主要用于 解码器的两种注意力:
Masked Self-Attention(解码器自身)
K和V来自 解码器自身- 需要 缓存已生成的部分,每次只计算当前 token 的
Q
Cross-Attention(编码器-解码器)
K和V来自 编码器- 由于编码器的
K/V不会变化,可以直接缓存 整个encoder_output,不需要每次重新计算。
示意图
不使用 KV Cache
每次解码一个新 token,所有 K/V 都需要重新计算:
$$ Step 1: Attention(Q_1, [K_1], [V_1]) $$ $$ Step 2: Attention(Q_2, [K_1, K_2], [V_1, V_2]) $$ $$ Step 3: Attention(Q_3, [K_1, K_2, K_3], [V_1, V_2, V_3]) $$
计算量不断增加。
使用 KV Cache
只计算新 token 的 K/V,其余的直接复用:
$$ Step 1: Attention(Q_1, [K_1], [V_1]) -> Cache [K_1, V_1] $$ $$ Step 2: Attention(Q_2, [K_1, K_2], [V_1, V_2]) -> Cache [K_1, K_2, V_1, V_2] $$ $$ Step 3: Attention(Q_3, [K_1, K_2, K_3], [V_1, V_2, V_3]) -> Cache [K_1, K_2, K_3, V_1, V_2, V_3] $$
减少计算量,提高推理速度。
总结
KV Cache只用于解码器(Decoder) 的K和V- Masked Self-Attention 使用 KV Cache,每次只计算新 token 的
K/V - Cross-Attention
K/V由编码器输出,直接缓存,不需要更新 - 这样可以 提高推理速度,减少计算冗余
在推理阶段(即逐个解码 token 时),每个 DecoderBlock 的前面
seq_len个 token 的输出是不会改变的,新增的 token 只会在每一层引入新的计算,不会反过来修改已存在的输出。
但这个「不会改变」只在以下前提下成立:
每次只处理一个新 token(即自回归、单步推理),之前的输出缓存下来;
注意力机制是 Masked 的,新 token 的 attention 只能看到它之前的;
前层输出不改变,后层就不会重新计算前面位置的输出;
残差连接不影响前面位置的表示,因为它们是逐 token 的操作;
🧠 深入理解每一步:以 DecoderBlock 为例解释 为什么前面 token 的输出不会改变
我们基于如下结构(和你之前的代码思路一致):
class DecoderBlock(nn.Module):
def __init__(self, embed_size, heads):
self.self_attn = MultiHeadAttention(embed_size, heads)
self.cross_attn = MultiHeadAttention(embed_size, heads)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, 4 * embed_size),
nn.ReLU(),
nn.Linear(4 * embed_size, embed_size),
)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.norm3 = nn.LayerNorm(embed_size)
def forward(self, x, encoder_output, self_kv_cache=None):
# 1. Masked Self-Attention
_x = self.self_attn(x, x, x, cache=self_kv_cache, mask=True)
x = self.norm1(x + _x)
# 2. Cross-Attention with encoder
_x = self.cross_attn(encoder_output, encoder_output, x)
x = self.norm2(x + _x)
# 3. Feed-forward
_x = self.feed_forward(x)
x = self.norm3(x + _x)
return x
🔍 逐步解释为什么旧的 token 输出不会被修改
1️⃣ Masked Self-Attention
假设你在位置
t新增一个 token,那么:之前位置
0 ~ t-1的query/key/value都在 cache 里,不变;当前位置
t会计算出新的Q_t,并将K_t/V_t加入缓存;由于 attention mask 是 只允许看到之前位置,所以:
新的 token
Q_t可能 attend 到[K_0, ..., K_{t-1}]但旧的 token
Q_i (i < t)并不会重新计算或 attend 到K_t(新加的 key)
✅ 旧 token 不受新 token 影响。
2️⃣ Cross-Attention
输入为
encoder_output(固定) 和x(decoder 的当前输出)对位置
t来说,它会用Q_t去 attend 编码器的K_enc, V_enc;对位置
0 ~ t-1,它们的 Q 也在缓存中,没必要重新计算;
✅ decoder 中旧 token 的 cross-attention 输出不会因新 token 改变。
3️⃣ Feed-Forward Network
每个 token 都是独立输入一个小 MLP 的,即:
x_i_out = FFN(x_i) # 对每个 i 单独做
因此位置
0 ~ t-1的 token 的 FFN 输出不会因为加入了位置t而改变。即便是在训练时整个序列一起计算,它其实也是并行执行多个 FFN,每个 token 的计算是独立的。
✅ MLP 层(全连接)不会影响其他位置。
4️⃣ LayerNorm 和残差连接
x + f(x)是 element-wise 的;LayerNorm是对每个 token 的特征向量归一化;不会混合位置之间的信息;
只改变当前 token 自己的表示分布;
✅ 不会跨 token 影响其他位置。
🧠 总结一句话:
在 Transformer 解码器中,每个位置的输出都是局部计算的,所以只要你不重新运行前面的 token,它们的值就不会被改变,新 token 的加入只会增加一个新位置的计算,不会干扰之前的结果。
🧪 举个数值例子
假设我们在解码器中已有输出:
x = [x₀, x₁] => shape: [2, embed_size]
现在我们生成了一个新 token,成为:
x = [x₀, x₁, x₂] => shape: [3, embed_size]
在推理时我们只会计算
x₂的 self-attention(Q₂),不会重新计算x₀/x₁旧的
[x₀, x₁]是已经 cache 下来的,输出固定,不变;新的输出
x₂会进入下一层 decoder block,但[x₀, x₁]的输出保持原样传下去。
我们先来看可视化示意图,然后我会从多个角度分析:你说法在哪些情况下可能不成立。
🖼️ 解码器中的 Token 推理流程示意图
场景:我们已经生成了前两个 token,现在要生成第三个 token。
┌────────────── Decoder Layer 1 ──────────────┐
│ │
│ x₀ ──────► unchanged │
│ x₁ ──────► unchanged │
│ x₂ ──────► only x₂ gets processed │
│ │
└─────────────────────────────────────────────┘
│
▼
┌────────────── Decoder Layer 2 ──────────────┐
│ │
│ x₀ ──────► cache reused │
│ x₁ ──────► cache reused │
│ x₂ ──────► new output │
│ │
└─────────────────────────────────────────────┘
│
▼
...
📦 每个位置的计算逻辑:
| Token | Masked Self-Attn | Cross-Attn (Encoder) | FFN | 输出是否会变化 |
|---|---|---|---|---|
| x₀ | 不重新计算(缓存) | 不重新计算 | 不重新计算 | ❌ 不变 |
| x₁ | 不重新计算(缓存) | 不重新计算 | 不重新计算 | ❌ 不变 |
| x₂ | ✅ 新计算 | ✅ 新计算 | ✅ 新计算 | ✅ 只影响 x₂ |
上一篇:[[Decoder]] 下一篇:[[Flash Attention]]