GQA

上一篇:[[RoPE]] 下一篇:[[SwiGLU]]

分组查询注意力(Grouped Query Attention, GQA)将查询头分组,共享键和值头以提高效率,相比传统的多头注意力(Multi-Head Attention, MHA)。GQA根据键值头的数量对查询头进行分组,在保持模型性能的同时提高了计算效率。
为了创建一个简单的数值示例,基于文章中的解释生成代码。

一个分组查询注意力(GQA)的简单数值示例用PyTorch实现:

import torch
from einops import rearrange, einsum

# 示例维度
batch_size = 1
seq_len = 256
num_query_heads = 8
num_key_value_heads = 2
head_dim = 64

# 初始化随机张量
query = torch.randn(batch_size, seq_len, num_query_heads, head_dim)  # 形状: (1, 256, 8, 64)
key = torch.randn(batch_size, seq_len, num_key_value_heads, head_dim)  # 形状: (1, 256, 2, 64)
value = torch.randn(batch_size, seq_len, num_key_value_heads, head_dim)  # 形状: (1, 256, 2, 64)

# 实际在使用过程中的k和v会被投影到更小的空间

# 计算头组数量
num_head_groups = query.shape[2] // key.shape[2]
print(f"头组数量: {num_head_groups}")

# b: batch_size = 1
# l: seq_len = 256
# q: num_query_heads = 8
# k: num_key_value_heads = 2
# d: head_dim = 64
# g: num_head_groups (num_query_heads/num_key_value_heads) = 4

# 重排张量以高效计算
query = rearrange(query, "b l q d -> b q l d")  # 形状: (1, 8, 256, 64)
key = rearrange(key, "b l k d -> b k l d")  # 形状: (1, 2, 256, 64)
value = rearrange(value, "b l k d -> b k l d")  # 形状: (1, 2, 256, 64)

# 分组查询 (1, 8, 256, 64) -> (1, 4, 2, 256, 64)
query = rearrange(query, "b (k g) l d -> b g k l d", g=num_head_groups)  # 形状: (1, 4, 2, 256, 64)

# 计算注意力分数 (1, 4, 2, 256, 64)@(1, 2, 256, 64) -> (1, 4, 2, 256, 256) -reduce-> (1, 2, 256, 256)
scores = einsum(query, key, "b g k l d, b k l d -> b k l l")  # 形状: (1, 2, 256, 256)
print(f"注意力分数形状: {scores.shape}")

# 最终注意力输出
scale = query.size(-1) ** 0.5
attention = torch.nn.functional.softmax(scores / scale, dim=-1)

# 与值张量相乘 (1, 2, 256, 256)@(1, 2, 256, 64) -> (1, 2, 256, 64)
out = einsum(attention, value, "b k l l, b k l d -> b k l d")  # 形状: (1, 2, 256, 64)

# 重排回原始形状
out = rearrange(out, "b k l d -> b l k d")  # 形状: (1, 256, 2, 64)
print(f"最终输出形状: {out.shape}")

可以看到scores张量与值张量形状相同。但实际发生了什么?

在幕后,einsum做了两件事:

  1. querykey进行矩阵乘法。在我们的示例中,这些张量的形状分别为(1, 4, 2, 256, 64)和(1, 2, 256, 64),因此沿最后两个维度的矩阵乘法得到(1, 4, 2, 256, 256)。
  2. 现在需要沿第二个维度(维度g)求和——如果输出形状中省略了该维度,einsum会自动完成求和操作。这种求和是为了匹配键和值张量的头数量。

输入x: (1,256,512) # 假设 embed_dim=512=8*64 -> 通过Wq投影: (1,256,8,64) -> 通过Wk/Wv投影: (1,256,2,64) # 这里key/value的头数较少 -> 计算注意力… -> 输出: (1,256,2,64) = (1,256,128) -> 通过输出投影层: (1,256,512) # 回到原始维度用于残差连接

上一篇:[[RoPE]] 下一篇:[[SwiGLU]]