RMSNorm 是一种改进的 Layer Norm,不需要记录方差和均值,只需要一份权重,节省资源。
之前要用上 RMSNorm 都需要自己写一份。
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def forward(self, x):
output = self._norm(x.float())
output = output * (1.0 + self.weight.float())
return output.type_as(x)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
代码来源于 Google 的 Gemma 模型实现。其中反复出现的
.float()是保证运算精度为 float32。
把 PyTorch 更新到 2.4 版本,就可以通过 from torch.nn import RMSNorm 调用原生实现了,用起来和上面这个实现差不多。
PyTorch 的 RMSNorm 还可以传入 elementwise_affine=False 来跳过 output = output * (1.0 + self.weight.float()) 这个步骤。
我觉得比较有意思的是 eps 的缺省值。PyTorch 的实现并没有设置 1e-6 这样的缺省值,而是 torch.finfo(x.dtype).eps 取机器精度(在该数据类型下两个 1.0 之间的最小差异)。如此,不仅尽可能减少小了 eps 在运算中的影响,还能根据运算精度选取最合适的 eps。