Decoder

Next: [[KV Cache]]

The Decoder code below is rewritten to guarantee three things:

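  1. Cross-attention: the query comes from the decoder; the key and value come from the encoder.
  2. Self-attention needs a mask so the decoder cannot see future tokens (this prevents information leakage).
  3. At inference time decoding proceeds one token at a time, matching the autoregressive property, while training can still be computed in parallel.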
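A minimal sketch of points 2 and 3, assuming a single head and random scores standing in for real QK^T / sqrt(d_k) values (`causal_mask` is a hypothetical helper name). The lower-triangular mask lets all positions be computed in one matrix operation, which is the parallel training case, while each row still assigns zero weight to later positions.

```python
import torch


def causal_mask(T: int) -> torch.Tensor:
    # Lower-triangular matrix: entry (i, j) is 1 when j <= i, i.e. position i may attend to j.
    return torch.tril(torch.ones(T, T))


T = 4
scores = torch.randn(T, T)  # stand-in for QK^T / sqrt(d_k) of one head
mask = causal_mask(T)

# Same convention as the MultiHeadAttention code: mask == 0 means "blocked".
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = torch.softmax(scores, dim=-1)  # exp(-inf) = 0, so blocked entries get weight 0

print(mask)
print(weights)  # every entry above the diagonal is exactly 0; each row sums to 1
```

All four rows are produced by one softmax call, yet row i only mixes positions 0..i, so training on a full target sequence leaks nothing from the future.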

Updated Decoder code

```python
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    forward(value, key, query, mask): note the argument order is (value, key, query).
    mask must broadcast to (N, heads, query_len, key_len); positions where mask == 0 are blocked.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size must be divisible by number of heads"

        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, value, key, query, mask=None):
        N = query.shape[0]
        value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

        # Linear projections, then split into heads: (N, len, heads, head_dim)
        values = self.values(value).view(N, value_len, self.heads, self.head_dim)
        keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)

        # Scaled scores QK^T / sqrt(d_k): (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) / (self.head_dim ** 0.5)

        # Apply the mask: blocked positions get -inf, so softmax gives them weight 0
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-inf"))

        attention = torch.softmax(energy, dim=-1)

        # Weighted sum of values, then merge heads back to (N, query_len, embed_size)
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(N, query_len, self.embed_size)
        return self.fc_out(out)


class DecoderBlock(nn.Module):
    """A single decoder layer."""

    def __init__(self, embed_size, heads, forward_expansion):
        super().__init__()
        self.self_attention = MultiHeadAttention(embed_size, heads)
        self.cross_attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.norm3 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

    def forward(self, x, encoder_out, src_mask, tgt_mask):
        """Decoder block flow (residual connection followed by LayerNorm at each step):
        1. Masked self-attention -> Add & Norm
        2. Cross-attention (encoder-decoder attention) -> Add & Norm
        3. Feed-forward network -> Add & Norm
        """

        # Masked self-attention: each position sees only itself and earlier tokens
        self_attention = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(self_attention + x)

        # Cross-attention: query from the decoder (x), key/value from the encoder output
        # (argument order is value, key, query)
        cross_attention = self.cross_attention(encoder_out, encoder_out, x, src_mask)
        x = self.norm2(cross_attention + x)

        # Position-wise feed-forward network
        ff_out = self.feed_forward(x)
        out = self.norm3(ff_out + x)

        return out


class Decoder(nn.Module):
    """Full decoder: embeddings, a stack of DecoderBlocks, and an output projection."""

    def __init__(self, target_vocab_size, embed_size, num_layers, heads, forward_expansion, max_length):
        super().__init__()
        self.embed_size = embed_size
        self.word_embedding = nn.Embedding(target_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList([
            DecoderBlock(embed_size, heads, forward_expansion) for _ in range(num_layers)
        ])

        self.fc_out = nn.Linear(embed_size, target_vocab_size)

    def forward(self, x, encoder_out, src_mask, tgt_mask):
        """Decoder forward pass:
        1. Token embedding plus learned position embedding
        2. Pass through the stack of DecoderBlocks
        3. Project to vocabulary size (logits)

        x: (N, seq_length) target token ids, with seq_length <= max_length
        encoder_out: (N, src_len, embed_size)
        returns: (N, seq_length, target_vocab_size) logits
        """
        N, seq_length = x.shape
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)
        x = self.word_embedding(x) + self.position_embedding(positions)

        for layer in self.layers:
            x = layer(x, encoder_out, src_mask, tgt_mask)

        return self.fc_out(x)
```
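To call `Decoder.forward`, `src_mask` and `tgt_mask` must broadcast against the score tensor `energy` of shape (N, heads, query_len, key_len), with 0 meaning "blocked". The sketch below shows one way to build them; `PAD_ID = 0`, `make_src_mask`, and `make_tgt_mask` are assumptions for illustration and are not part of the code above.

```python
import torch

PAD_ID = 0  # hypothetical padding token id


def make_src_mask(src: torch.Tensor) -> torch.Tensor:
    # src: (N, S) token ids -> (N, 1, 1, S); True = real token, False = padding.
    # Broadcasts over heads and query positions of the cross-attention scores (N, heads, T, S).
    return (src != PAD_ID).unsqueeze(1).unsqueeze(2)


def make_tgt_mask(tgt: torch.Tensor) -> torch.Tensor:
    # tgt: (N, T) -> (N, 1, T, T): causal (lower-triangular) AND not-a-padded-key.
    N, T = tgt.shape
    causal = torch.tril(torch.ones(T, T, device=tgt.device)).bool()  # (T, T)
    pad = (tgt != PAD_ID).unsqueeze(1).unsqueeze(2)                   # (N, 1, 1, T)
    return causal.unsqueeze(0).unsqueeze(0) & pad                     # (N, 1, T, T)


src = torch.tensor([[5, 6, 7, 0]])  # one source sentence ending in one pad token
tgt = torch.tensor([[1, 8, 9]])     # decoder input, e.g. starting with a <bos> id
src_mask = make_src_mask(src)
tgt_mask = make_tgt_mask(tgt)
print(src_mask.shape, tgt_mask.shape)  # torch.Size([1, 1, 1, 4]) torch.Size([1, 1, 3, 3])
```

Both masks are boolean; the `mask == 0` test in `MultiHeadAttention` treats `False` as blocked, so they can be passed straight to `Decoder.forward`.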

Key changes and improvements

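  1. Fixed the query, key, value order in cross-attention: `MultiHeadAttention.forward` takes `(value, key, query, mask)`, so the block calls `self.cross_attention(encoder_out, encoder_out, x, src_mask)`, taking key and value from the encoder and the query from the decoder.

  2. Masked self-attention: `tgt_mask` blocks future positions, so position i can only attend to positions 0..i.

  3. Step-by-step decoding: at inference the Decoder is called once per generated token on the growing prefix; during training the whole target sequence goes through in one parallel pass, with the causal mask preventing leakage.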
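A sketch of item 3 as greedy autoregressive decoding, under stated assumptions: batch size 1; `decoder_fn` is any callable returning logits of shape (1, T, vocab), for example a hypothetical wrapper around `Decoder.forward` with `src_mask` already bound; `toy_decoder` is a stand-in for a trained model, not real model behavior.

```python
from typing import Callable

import torch


def greedy_decode(
    decoder_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
    encoder_out: torch.Tensor,
    bos_id: int,
    eos_id: int,
    max_len: int,
) -> torch.Tensor:
    """Greedy decoding for batch size 1.

    With the Decoder above, keep max_len + 1 <= its max_length (learned position table).
    """
    tgt = torch.tensor([[bos_id]], device=encoder_out.device)
    for _ in range(max_len):
        T = tgt.shape[1]
        tgt_mask = torch.tril(torch.ones(T, T, device=tgt.device)).view(1, 1, T, T)
        logits = decoder_fn(tgt, encoder_out, tgt_mask)           # recomputes all T positions
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # (1, 1): only the last position is used
        tgt = torch.cat([tgt, next_id], dim=1)                    # append and feed back in
        if next_id.item() == eos_id:
            break
    return tgt


VOCAB = 10


def toy_decoder(tgt, encoder_out, tgt_mask):
    # Stand-in model (ignores encoder_out and tgt_mask): always predicts (token + 1) % VOCAB.
    return torch.nn.functional.one_hot((tgt + 1) % VOCAB, VOCAB).float()


enc = torch.zeros(1, 5, 8)  # dummy encoder output: N=1, src_len=5, embed_size=8
out = greedy_decode(toy_decoder, enc, bos_id=1, eos_id=4, max_len=10)
print(out)  # tensor([[1, 2, 3, 4]])
```

Because every step re-runs the decoder over the entire prefix, the work per step grows with the sequence length; caching the keys and values of earlier positions is what the KV cache in the next note is about.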


Summary

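This Decoder can now genuinely interact with the Encoder through cross-attention, so together with an Encoder you can use it to build a complete Transformer!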
