Llama代码解读注释 - v1.0

70 阅读2分钟

Llama流程图.svg

1. LlamaRMSNorm

LlamaRMSNorm 是 方根均值归一化层

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size,eps=1e-6):
        """
        LlamaRMSNorm 是 方根均值归一化层
        输入参数:
        hidden_size: 中间层的神经元个数, eg: 实例化中的 1024
        eps: 一个极小的实数,eg:1e-6 ; 作用:防止分母除以 0 
        """
        super().__init__()
        # 可学习的权重参数 [1,1,1,.....,1] --> shape: torch.Size([1024])
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
​
    def forward(self, hidden_states):
        # 统一矩阵元素的数据类型,方便进行四则运算
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        # 每个样本在最后一个维度上的平方,并求平均。这将得到每个样本对应的方差
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        '''
        计算输入张量hidden_states标准化的值, 具体如下:
        将输入张量hidden_states 除以 标准差 ,防止分母除以 0 ,我们对方差开方之前 加上了 1e-6,即 self.variance_epsilon
        '''
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        '''
        将标准化的张量hidden_states 乘以 可学习的权重参数 self.weight 得到最终的LlamaRMSNorm 层的输出.
        '''
        return self.weight * hidden_states.to(input_dtype)
    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"# 实例化
if __name__ == '__main__':
    net = LlamaRMSNorm(1024)
    output = net(torch.randn(4,30,1024))

torch.rsqrt()

  • 用法:对每个元素去平方根后再取倒数,公式如下

    outputi=1inputioutput_i = \frac{1}{\sqrt{input_i}}

该函数是元素的操作,并不会影响维度的大小

  • 一个具体案例:
import torch
​
input_x = torch.Tensor([[1,4],[16,25]])
output_y = torch.rsqrt(input_x)
print(output_y)
​
# 输出
'''
tensor([[1.0000, 0.5000],
        [0.2500, 0.2000]])
'''

torch.pow()

  • 用法:torch.pow(base,exp),对元素的求幂运算,公式如下:

    outputi=baseexpioutput_i = base ^{exp_i}

    该函数是元素的操作,并不会影响维度的大小

torch.mean()

  • 用法: torch.mean(input,dim,keepdim)input.mean(dim)是一样的。针对给定的Tensor类型矩阵input的维度(dim)求平均值

    • keepdim 张量是否保留具有的维度,默认值为 False
  • 举例

    import torch
    ​
    a = torch.Tensor([0,1,2,3,4,5]).view(2,3,1) # a.shape --> [2,3,1]
    mean_0 = torch.mean(a,0)  # mean_0.shape --> [3,1]
    mean_1 = torch.mean(a,1)  # mean_1.shape --> [2,1]
    mean_2 = torch.mean(a,2)  # mean_2.shape --> [2,3]# keepdim = Ture
    mean_0 = torch.mean(a,0,keepdim=Ture)  # mean_0.shape --> [1,3,1]
    mean_1 = torch.mean(a,1,keepdim=Ture)  # mean_1.shape --> [2,1,1]
    mean_2 = torch.mean(a,2,keepdim=Ture)  # mean_2.shape --> [2,3,1]