KV Cache

上一篇:[[Decoder]] 下一篇:[[Flash Attention]]

KV Cache 指的是 解码器(Decoder)Key (K)Value (V),在 推理(Inference) 时缓存,以提高计算效率。

在 Transformer 结构中:

为什么需要 KV Cache?

自回归(Autoregressive)推理 时:

  1. 普通 Transformer 解码:

  2. 使用 KV Cache 优化:


KV Cache 的作用

KV Cache 主要用于 解码器的两种注意力:

  1. Masked Self-Attention(解码器自身)

  2. Cross-Attention(编码器-解码器)


示意图

不使用 KV Cache

每次解码一个新 token,所有 K/V 都需要重新计算:

$$ Step 1: Attention(Q_1, [K_1], [V_1]) $$ $$ Step 2: Attention(Q_2, [K_1, K_2], [V_1, V_2]) $$ $$ Step 3: Attention(Q_3, [K_1, K_2, K_3], [V_1, V_2, V_3]) $$

计算量不断增加。

使用 KV Cache

只计算新 token 的 K/V,其余的直接复用:

$$ Step 1: Attention(Q_1, [K_1], [V_1]) -> Cache [K_1, V_1] $$ $$ Step 2: Attention(Q_2, [K_1, K_2], [V_1, V_2]) -> Cache [K_1, K_2, V_1, V_2] $$ $$ Step 3: Attention(Q_3, [K_1, K_2, K_3], [V_1, V_2, V_3]) -> Cache [K_1, K_2, K_3, V_1, V_2, V_3] $$

减少计算量,提高推理速度。


总结

在推理阶段(即逐个解码 token 时),每个 DecoderBlock 的前面 seq_len 个 token 的输出是不会改变的,新增的 token 只会在每一层引入新的计算,不会反过来修改已存在的输出。

但这个「不会改变」只在以下前提下成立:

  1. 每次只处理一个新 token(即自回归、单步推理),之前的输出缓存下来;

  2. 注意力机制是 Masked 的,新 token 的 attention 只能看到它之前的;

  3. 前层输出不改变,后层就不会重新计算前面位置的输出

  4. 残差连接不影响前面位置的表示,因为它们是逐 token 的操作


🧠 深入理解每一步:以 DecoderBlock 为例解释 为什么前面 token 的输出不会改变

我们基于如下结构(和你之前的代码思路一致):

class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads):
        self.self_attn = MultiHeadAttention(embed_size, heads)
        self.cross_attn = MultiHeadAttention(embed_size, heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, 4 * embed_size),
            nn.ReLU(),
            nn.Linear(4 * embed_size, embed_size),
        )
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.norm3 = nn.LayerNorm(embed_size)

    def forward(self, x, encoder_output, self_kv_cache=None):
        # 1. Masked Self-Attention
        _x = self.self_attn(x, x, x, cache=self_kv_cache, mask=True)
        x = self.norm1(x + _x)

        # 2. Cross-Attention with encoder
        _x = self.cross_attn(encoder_output, encoder_output, x)
        x = self.norm2(x + _x)

        # 3. Feed-forward
        _x = self.feed_forward(x)
        x = self.norm3(x + _x)

        return x

🔍 逐步解释为什么旧的 token 输出不会被修改

1️⃣ Masked Self-Attention

旧 token 不受新 token 影响


2️⃣ Cross-Attention

decoder 中旧 token 的 cross-attention 输出不会因新 token 改变


3️⃣ Feed-Forward Network

每个 token 都是独立输入一个小 MLP 的,即:

x_i_out = FFN(x_i)  # 对每个 i 单独做

MLP 层(全连接)不会影响其他位置


4️⃣ LayerNorm 和残差连接

不会跨 token 影响其他位置


🧠 总结一句话:

在 Transformer 解码器中,每个位置的输出都是局部计算的,所以只要你不重新运行前面的 token,它们的值就不会被改变,新 token 的加入只会增加一个新位置的计算,不会干扰之前的结果。


🧪 举个数值例子

假设我们在解码器中已有输出:

x = [x₀, x₁]   => shape: [2, embed_size]

现在我们生成了一个新 token,成为:

x = [x₀, x₁, x₂]  => shape: [3, embed_size]

我们先来看可视化示意图,然后我会从多个角度分析:你说法在哪些情况下可能不成立


🖼️ 解码器中的 Token 推理流程示意图

场景:我们已经生成了前两个 token,现在要生成第三个 token。

    ┌────────────── Decoder Layer 1 ──────────────┐
    │                                             │
    │  x₀ ──────►  unchanged                      │
    │  x₁ ──────►  unchanged                      │
    │  x₂ ──────►  only x₂ gets processed         │
    │                                             │
    └─────────────────────────────────────────────┘
             │
             ▼
    ┌────────────── Decoder Layer 2 ──────────────┐
    │                                             │
    │  x₀ ──────►  cache reused                   │
    │  x₁ ──────►  cache reused                   │
    │  x₂ ──────►  new output                     │
    │                                             │
    └─────────────────────────────────────────────┘
             │
             ▼
           ...

📦 每个位置的计算逻辑:

TokenMasked Self-AttnCross-Attn (Encoder)FFN输出是否会变化
x₀不重新计算(缓存)不重新计算不重新计算❌ 不变
x₁不重新计算(缓存)不重新计算不重新计算❌ 不变
x₂✅ 新计算✅ 新计算✅ 新计算✅ 只影响 x₂

上一篇:[[Decoder]] 下一篇:[[Flash Attention]]