PyTorch 于 2.4 版本原生支持 RMSNorm

651 阅读1分钟

RMSNorm 是一种改进的 Layer Norm,不需要记录方差和均值,只需要一份权重,节省资源。

之前要用上 RMSNorm 都需要自己写一份。

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        output = self._norm(x.float())
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

代码来源于 Google 的 Gemma 模型实现。其中反复出现的 .float() 是保证运算精度为 float32。

把 PyTorch 更新到 2.4 版本,就可以通过 from torch.nn import RMSNorm 调用原生实现了,用起来和上面这个实现差不多。

PyTorch 的 RMSNorm 还可以传入 elementwise_affine=False 来跳过 output = output * (1.0 + self.weight.float()) 这个步骤。

我觉得比较有意思的是 eps 的缺省值。PyTorch 的实现并没有设置 1e-6 这样的缺省值,而是 torch.finfo(x.dtype).eps 取机器精度(在该数据类型下两个 1.0 之间的最小差异)。如此,不仅尽可能减少小了 eps 在运算中的影响,还能根据运算精度选取最合适的 eps。