上一篇:[[GQA]] 下一篇:[[Train]]
SwiGLU 是一种非线性函数结构,原始公式为:
$$ \text{SwiGLU}(x) = \text{Swish}(xW_1) \odot (xW_2) $$
但 Swish(x) = x × sigmoid(x),所以可以展开为:
$$ \text{SwiGLU}(x) = \left(xW_1 \cdot \sigma(xW_1)\right) \odot (xW_2) $$
这里的流程是这样的:
研究发现,门控线性单元(GLU)能有效捕捉序列中的长程依赖关系,同时避免LSTM和GRU等其他门控机制常见的梯度消失问题。
了解Swish和GLU后可知,SwiGLU是二者的结合体。它采用GLU结构,但用Swish函数(设ß=1)替代了原有的sigmoid激活函数,最终公式如下:
SwiGLU(x) = Swish(W1x+b)⊗(Vx+c)
现在让我们构建一个包含SwiGLU函数的前馈网络。遵循Transformer架构并省略偏置项后:
FFNSwiGLU(x) = (Swish1(xW)⊗xV)W2
在PyTorch中可按如下方式实现:
class SwiGLU(nn.Module):
def __init__(self, w1, w2, w3) -> None:
super().__init__()
self.w1 = w1
self.w2 = w2
self.w3 = w3
def forward(self, x):
x1 = F.linear(x, self.w1.weight)
x2 = F.linear(x, self.w2.weight)
hidden = F.silu(x1) * x2
return F.linear(hidden, self.w3.weight)
此处使用的Silu函数在ß=1时与Swish函数等价。
对比使用SwiGLU与其他GLU变体的Transformer模型表现可见,SwiGLU在预训练阶段始终展现出更优性能。
Swish(包含 sigmoid)是 连续且可导 的(ReLU 在 0 处不可导)。这在训练时对梯度传播更友好,尤其是大模型和长上下文中。
SwiGLU 是 GLU 的一种,带来了「门控机制」:
多个实验(包括 Meta 的 LLaMA 和 Google 的 PaLM)表明,SwiGLU(或其变体如 GeGLU)在相同计算成本下效果更好。特别是在大模型上,ReLU 会造成信息损失,而 SwiGLU 保留更多信号。
| 激活函数 | 计算速度 | 表达能力 | 是否门控 | 性能(大模型) |
|---|---|---|---|---|
| ReLU | ✅ 快 | ❌ 简单 | ❌ 否 | 较差 |
| SwiGLU | 稍慢 | ✅ 强 | ✅ 是 | ✅ 更好 |
上一篇:[[GQA]] 下一篇:[[Train]]