SwiGLU

上一篇:[[GQA]] 下一篇:[[Train]]

SwiGLU 激活函数

SwiGLU 是一种非线性函数结构,原始公式为:

$$ \text{SwiGLU}(x) = \text{Swish}(xW_1) \odot (xW_2) $$

但 Swish(x) = x × sigmoid(x),所以可以展开为:

$$ \text{SwiGLU}(x) = \left(xW_1 \cdot \sigma(xW_1)\right) \odot (xW_2) $$

这里的流程是这样的:

  1. 输入 $x \in \mathbb{R}^{d}$(比如长度为 16 的向量)
  2. 权重 $W_1, W_2 \in \mathbb{R}^{d \times d’}$,比如从 16 → 32(d’ 是 FFN 的中间维度)
  3. 做两个前馈变换:
  4. 最后元素乘 $u \odot v$

研究发现,门控线性单元(GLU)能有效捕捉序列中的长程依赖关系,同时避免LSTM和GRU等其他门控机制常见的梯度消失问题。

SwiGLU

了解Swish和GLU后可知,SwiGLU是二者的结合体。它采用GLU结构,但用Swish函数(设ß=1)替代了原有的sigmoid激活函数,最终公式如下:

SwiGLU(x) = Swish(W1x+b)⊗(Vx+c)

现在让我们构建一个包含SwiGLU函数的前馈网络。遵循Transformer架构并省略偏置项后:

FFNSwiGLU(x) = (Swish1(xW)⊗xV)W2

如何实现SwiGLU?

在PyTorch中可按如下方式实现:

class SwiGLU(nn.Module):
    def __init__(self, w1, w2, w3) -> None:
        super().__init__()
        self.w1 = w1
        self.w2 = w2
        self.w3 = w3
    def forward(self, x):
        x1 = F.linear(x, self.w1.weight)
        x2 = F.linear(x, self.w2.weight)
        hidden = F.silu(x1) * x2
        return F.linear(hidden, self.w3.weight)

此处使用的Silu函数在ß=1时与Swish函数等价。

对比使用SwiGLU与其他GLU变体的Transformer模型表现可见,SwiGLU在预训练阶段始终展现出更优性能。

✅ 更平滑

Swish(包含 sigmoid)是 连续且可导 的(ReLU 在 0 处不可导)。这在训练时对梯度传播更友好,尤其是大模型和长上下文中。

✅ 引入门控(GLU)

SwiGLU 是 GLU 的一种,带来了「门控机制」:

✅ 更强的表达能力

多个实验(包括 Meta 的 LLaMA 和 Google 的 PaLM)表明,SwiGLU(或其变体如 GeGLU)在相同计算成本下效果更好。特别是在大模型上,ReLU 会造成信息损失,而 SwiGLU 保留更多信号。


⚖️ 总结对比

激活函数计算速度表达能力是否门控性能(大模型)
ReLU✅ 快❌ 简单❌ 否较差
SwiGLU稍慢✅ 强✅ 是✅ 更好

Norm顺序改进:Pre-Norm

上一篇:[[GQA]] 下一篇:[[Train]]