下一篇:[[KV Cache]]
下面重新编写 Decoder 相关代码，确保以下几点：
query 来自解码器，key 和 value 来自编码器。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        embed_size: model dimension E of the inputs and of the output.
        heads: number of attention heads; must divide ``embed_size`` evenly.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size must be divisible by number of heads"
        # Bias-free Q/K/V projections (each E -> E, later split into heads) and an
        # output projection with bias.
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, value, key, query, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Note the argument order is (value, key, query).

        Args:
            value: tensor of shape (N, value_len, embed_size).
            key: tensor of shape (N, key_len, embed_size); key_len must equal
                value_len for the weighted sum to be defined.
            query: tensor of shape (N, query_len, embed_size).
            mask: optional tensor broadcastable to (N, heads, query_len, key_len);
                scores where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

        # Project, then split the last dimension into (heads, head_dim).
        values = self.values(value).view(N, value_len, self.heads, self.head_dim)
        keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)

        # Scaled dot-product scores QK^T / sqrt(d_k): shape (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) / (self.head_dim ** 0.5)

        if mask is not None:
            # Fill with the most negative finite value rather than -inf. If every
            # key in a row is masked, -inf makes softmax return NaN, which then
            # poisons every later layer. A finite fill yields a finite (uniform)
            # distribution instead. For rows with at least one visible key the
            # result is unchanged, because exp(min - max) underflows to exactly 0.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)

        # Attention-weighted sum of values, then merge heads back to embed_size.
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.embed_size
        )
        return self.fc_out(out)
class DecoderBlock(nn.Module):
    """A single post-norm Transformer decoder layer.

    Three sub-layers, each wrapped as ``LayerNorm(input + sublayer(input))``:
    masked self-attention, encoder-decoder (cross) attention, and a
    position-wise feed-forward network.

    Args:
        embed_size: model dimension.
        heads: number of attention heads in both attention sub-layers.
        forward_expansion: hidden-size multiplier of the feed-forward network.
    """

    def __init__(self, embed_size, heads, forward_expansion):
        super().__init__()
        hidden_size = forward_expansion * embed_size
        # Registration order fixes the parameter-initialisation order and the
        # state_dict layout, so keep it stable.
        self.self_attention = MultiHeadAttention(embed_size, heads)
        self.cross_attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.norm3 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )

    def forward(self, x, encoder_out, src_mask, tgt_mask):
        """Apply the three sub-layers to decoder states ``x``.

        Args:
            x: decoder hidden states; used as the query in both attention sub-layers.
            encoder_out: encoder output; supplies keys and values for cross-attention.
            src_mask: mask passed to cross-attention.
            tgt_mask: mask passed to self-attention (restricts each token to
                previously generated tokens).

        Returns:
            Updated hidden states.
        """
        # Masked self-attention: query, key and value all come from x.
        hidden = self.norm1(x + self.self_attention(x, x, x, tgt_mask))
        # Cross-attention. MultiHeadAttention takes (value, key, query), so the
        # encoder output fills value/key and the decoder state is the query.
        hidden = self.norm2(
            hidden + self.cross_attention(encoder_out, encoder_out, hidden, src_mask)
        )
        # Position-wise feed-forward network.
        return self.norm3(hidden + self.feed_forward(hidden))
class Decoder(nn.Module):
    """Complete decoder: token embeddings plus learned positional embeddings,
    a stack of DecoderBlock layers, and a final projection to vocabulary logits.

    Args:
        target_vocab_size: size of the output vocabulary.
        embed_size: model dimension.
        num_layers: number of stacked DecoderBlock layers.
        heads: number of attention heads per block.
        forward_expansion: hidden-size multiplier of each block's feed-forward net.
        max_length: size of the learned position table, i.e. the longest target
            sequence this decoder can process.
    """

    def __init__(self, target_vocab_size, embed_size, num_layers, heads, forward_expansion, max_length):
        super().__init__()
        self.embed_size = embed_size
        self.word_embedding = nn.Embedding(target_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)
        self.layers = nn.ModuleList(
            [DecoderBlock(embed_size, heads, forward_expansion) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, target_vocab_size)

    def forward(self, x, encoder_out, src_mask, tgt_mask):
        """Run the decoder on a batch of target token ids.

        Steps: add token and position embeddings, pass through every
        DecoderBlock, then project to vocabulary size.

        Args:
            x: target token ids of shape (N, seq_length).
            encoder_out: encoder output, passed unchanged to every layer.
            src_mask: cross-attention mask, passed unchanged to every layer.
            tgt_mask: self-attention mask, passed unchanged to every layer.

        Returns:
            Logits of shape (N, seq_length, target_vocab_size).

        Raises:
            IndexError: if seq_length exceeds the position table size (max_length).
        """
        N, seq_length = x.shape
        max_length = self.position_embedding.num_embeddings
        if seq_length > max_length:
            # Fail early with a readable message. On CUDA an out-of-range
            # embedding lookup otherwise surfaces as an opaque device-side assert.
            raise IndexError(
                f"target sequence length {seq_length} exceeds max_length {max_length}"
            )
        # Create positions directly on the input's device rather than building
        # them on the CPU and copying them over on every call.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)
        x = self.word_embedding(x) + self.position_embedding(positions)
        for layer in self.layers:
            x = layer(x, encoder_out, src_mask, tgt_mask)
        return self.fc_out(x)
修正了交叉注意力中 query、key、value 的传参顺序：
- query 来自解码器（x）
- key 和 value 来自编码器（encoder_out）
- 自注意力使用 Masked Self-Attention，只能看见之前生成的 token
Decoder 逐步解码
在训练时,解码器可以一次性处理整个序列(因为 Mask 机制已经防止未来信息泄露)
在推理时,我们需要一个循环,逐步解码:
# Greedy autoregressive decoding (batch size 1).
# current_output: Python list of token ids generated so far; presumably it
# starts with a start-of-sequence token — confirm against the caller.
with torch.no_grad():  # inference only: no autograd graph needed
    for _ in range(max_length):
        tgt_mask = generate_mask(current_output)
        # The decoder unpacks x.shape as (N, seq_length), so wrap the list as a
        # (1, seq_length) LongTensor on the encoder output's device.
        tgt = torch.tensor([current_output], dtype=torch.long, device=encoder_out.device)
        decoder_out = decoder(tgt, encoder_out, src_mask, tgt_mask)
        # Take the most likely token at the last position; .item() needs batch size 1.
        next_token = decoder_out[:, -1, :].argmax(dim=-1).item()
        current_output.append(next_token)
        if next_token == end_token:
            break
这样 Decoder 才是自回归（Autoregressive）的：每个时间步只能用之前生成的 token 进行预测。
修正了 query、key、value 的错误之后，这个 Decoder 就可以真正和 Encoder 交互了，你可以用它实现完整的 Transformer 了！
下一篇:[[KV Cache]]