前言
这两年大模型是相当的火,笔者一直在从事大模型相关的工作,但更偏向于用模型:调整prompt,调用不同的模型,看效果是否达到业务的期望。对于笔者而言,模型就像一个黑盒一样,在不了解原理的情况下用起来总是不踏实,因此这篇文章旨在对llama3.1的模型源码进行一个解析,尝试将这个黑盒打开。
llama3.1模型源码地址:github.com/meta-llama/…
整个模型源码文件只有325行,不得不让人感叹,只有325行搭建起的一个模型,训练出了在评测分数上和gpt4o相近的llama3.1-405b。
预备知识
- 了解一些python的语法。
- 线性代数,了解基本的矩阵乘法即可。
- 了解一些torch的常用功能
首先会简单介绍下torch中一些功能,这些是后续理解模型的基础,有了这个基础后,再开始介绍模型的架构。
torch.nn.Linear
使用Linear可以实现矩阵的乘法,下面是一个示例
下面这个例子定义了一个2*3的矩阵matrix_2_3.weight,
假设我们的输入矩阵维度为 2*2,每个元素都是一个浮点数,
给输入矩阵a乘上matrix_2_3.weight矩阵的结果如下
torch.nn.Parameter
将矩阵的参数包装成torch.nn.Parameter,作为属性放在torch.nn.Module中,该矩阵的参数就变成了一个可训练的参数。
下述例子列出模型的参数时,只列出了包装在torch.nn.Parameter中的、大小为2x4的矩阵,未列出另一个大小为2x3的、未包装在Parameter中的张量
torch.sqrt
对输入张量的每个元素进行根号的计算,示例如下。
torch.var_mean
针对输入张量,沿着指定的方向,计算方差、均值并返回。
下述例子初始化了一个2*2的张量,dim=0,表示沿着维度0,也就是下述矩阵的列方向,分别计算平均值与方差,keepDim表示沿着计算的那列的维度会坍缩为1,但那列所在的维度不会消失。
torch.ones
初始化一个大小为指定维度,元素值全为1的张量
张量按位相乘
a是一个大小为(2,3)的张量,和另一个大小也为(2,3)的张量按位相乘结果如下
Tensor.view
以另外一种维度的布局来看当前tensor,下述例子将原本维度大小为(2,2,4)的张量按照(2,2,2,2)的方式来看
torch.arange
给定目标值 end,生成0到end间的序列,不包括end,间隔为1。
torch.outer
笛卡尔积,外积
torch.full
按照指定大小创建tensor,并填充值,下面这个例子创建了一个大小为(3,3)的张量,每个元素的值都设为浮点数表示的负无穷。
torch.triu
用于返回一个上三角矩阵,即对角线以上的元素保持不变,对角线以下的元素被设置为 0。返回上三角矩阵的时候,可指明从哪条对角线开始设置0。
softmax
每个分量拿e取个指数,然后除以所有分量拿e取指数后的和,得到分量的归一化表示。
模型架构
RMSNorm
模型中定义的第一个模块为 root mean square normalization,这是一个对向量进行归一化计算的模块。
一、什么是Normalization?
Normalization:规范化或标准化,把输入数据X规范化成在固定区间范围的标准分布。
二、深度学习中为什么要用Normalization?
Normalization 的作用很明显,把数据拉回标准正态分布,因为神经网络的Block大部分都是矩阵运算,一个向量经过矩阵运算后值会越来越大,为了网络的稳定性,我们需要及时把值拉回正态分布。
三、通常有几种Normalization方案?
以下2种:
- batch normalization
- layer normalization
笔者尝试用下图来解释下这2种normalization的区别,首先明确一下这个立方体中3个轴的含义
- batch dimension:批次大小,eg 一次送2条句子进入网络进行训练,batch_size=2
- seq_len:对应了句子经tokenizer token化后序列的长度,eg 你是谁被tokenizer token化后的id为 123,9283,12394,492。就表示你是谁这个句子对应的token序列的seq_len为4。
- feature:句子中的每个token,会embedding为一个向量,转成向量后就可以通过内积来度量token间的关联性大小,feature就用于指代这个向量,向量的大小有3种常见记法:hidden_size/dim/d_model。举个例子,你是谁token化后的第一个token的id为123,假设经过embedding作用后123这个token变成了(1,0,0)这个向量,9283这个token经过embedding作用后变成了(0,1,0)这个向量,两个向量内积为0,则表示两个token毫无关联。
batch normalization
batch normalization是沿着batch、seq_len2个维度对这一批的数据进行标准归一化,归一化的公式如下。
直接看公式还是有点抽象,上代码看下
class BatchNorm(nn.Module):
def __init__(
self,
size: int,
eps: float = 1e-5,
):
"""
Batch Normalization.
Assumes the shape of the input x is (batch, seq_len, d_model)
Args:
size: shape of the feature dimention (i.e. d_model)
eps: For numerical stability. Defaults to 1e-5.
"""
super(BatchNorm, self).__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(size), requires_grad=True)
self.beta = nn.Parameter(torch.ones(size), requires_grad=True)
def forward(self, x):
x_var, x_mean = torch.var_mean(x, dim=[0,1], keepdim=True, correction=0)
x_std = torch.sqrt(x_var + self.eps)
x_norm = (x - x_mean)/ x_std
return self.gamma.unsqueeze(0).unsqueeze(1) * x_norm + self.beta.unsqueeze(0).unsqueeze(1)
上述模块继承了nn.Module后,定义了2个可被训练的参数gamma、beta,对于输入的大小为(batch,seq_len,d_model)的张量,该模块首先会使用torch.var_mean沿着batch,seq_len2个方向计算方差和平均值。返回的x_var,x_mean都是(1,1,d_model)大小的张量,然后计算了方差加上一个极小值的根号结果。x_std为(1,1,d_model)大小,需要说明一下,这里x_var是张量,self.eps是浮点数,张量和浮点数相加时,浮点数会通过broadcast扩充到和张量同样的维度(1,1,d_model),扩充后所有的元素值和self.eps相同,然后再相加。
同理大小为(batch,seq_len,d_model)的x与大小为(1,1,d_model)的x_mean相减时也存在对应的广播。
root mean square normalization
root mean square的计算公式如下
直接看公式还是有点抽象,看下llama3.1里的RMSNorm代码实现
定义很简单,只包含了一个初始化全为1,大小为dim的一维张量作为模型训练的参数。
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
RMSNorm在计算这一批次数据的均值时,是按照最后一个维度【dim=-1】方向计算方差和均值的。这个就是和batch Normalization最根本的区别点。
笔者理解下来,layer Normalization是对一个个 token 进行embedding后得到的向量进行归一【对应于上面2个立方体中左侧立方体中标蓝的那一个token向量】,而batch Normalization是针对一批数据在一个特征上进行归一【对应于上面2个立方体中右侧立方体中标蓝的那一面向量分量,每个向量分量都取自这批数据里每个token的embedding向量的第一个分量】。
Attention
注意力模块是整个模型架构中最难理解的一个模块,包含了旋转位置编码,ntk缩放角度频率,query、key、value间的masked self attention计算,key和value的缓存。
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
Attentnion模块在初始化的时候定义了4个线性矩阵【实际上是ColumnParallelLinear,属于继承了Linear的子类,兼容了分布式的计算环境,抛开对于分布式计算的加成,其核心功能还是linear实现的矩阵相乘】,一些常数【key、value的头数量,query的头数量,为了确保多头中的每个头里都包含一份query、key、value时,key和value需要拷贝几份,每个头里向量的维度】,key、value的cache。
在深入代码细节前,先介绍3种attention的计算策略。
通常attention会分为多个头,每个头里采用相同的方式计算attention,key和value在每个头里有以下三种分配策略
- MHA:query:key:value=1:1:1,每个头内都是独一无二的一组query、key、value,当然也就对应了独一无二要训练的矩阵。
- MQA:query:key:value 的关系为所有的query共用一个key、value【实际计算时会进行一个拷贝,将key和value的个数拷贝到和query一样多】。
- GQA:query:key:value的关系为 几个query共用一个key、value,按照下述最右边的示意图,2个query共用1个key、value,那么总共有2个要训练的key激发矩阵、2个要训练的value激发矩阵、4个要训练的query激发矩阵。
考虑到计算的速度、attention的效果,通常会采用第三种方案-grouped-query attention,mqa所有的query共用一个key、value,效果较差,mha的推理时、训练时需要的计算量多,速度慢,而gqa则是均衡计算速度与推理效果的一种选择。
当传入输入的张量后,attention模块首先会进行矩阵相乘,对输入的大小为(batch_size、seq_len、n_dim)的矩阵x,矩阵x乘上大小为(n_dim,number_of_locak_headshead_dim)的query激发矩阵,激发出大小为(batch_size、seq_len、number_of_local_headshead_dim)的query向量,矩阵x乘上大小为(n_dim,number_of_local_kv_headshead_dim)的key激发矩阵,激发出大小为(batch_size、seq_len、number_of_local_kv_headshead_dim)的key向量,value的计算过程同key相同。
然后会通过view的方式,调整矩阵的维度,调整后的维度见view里依次罗列的参数
拿到query、key、value后,会针对query、key添加旋转位置编码,这里和attention is all you need中提到的位置编码处还有些区别。在attention is all you need论文中,位置信息在送往attention头计算前就已经加上了,然后再激发出query、key、value。但在llama3.1 model.py的实现中,只给query和key添加了位置信息,未对value添加位置信息。
位置信息是通过rope(旋转编码)的方式计算的,关于旋转编码,这里直接贴出了苏剑林给出的计算公式,关于推导过程,可参考苏建林的旋转编码博客。
这里对下面的符号做一个解释:(q0,q1,...qd-1)为seq_len中第m个token经embedding向量化后的表示,因为这个token位于整个token序列中第m个位置,所以要给这个token带上位置m相关的信息,通过对不同的分量分别乘上不同的cos、sin值来添加旋转位置编码信息。其中theta是一个经验值。
在看给query、key带上旋转编码的apply_rotary_emb的实现前,需要先理解下面这个函数precompute_freqs_cis,预计算rope里的角度,下面这段代码位置m属于[0,1,2,3....,end-1]计算了对应的m*theta的复数表示。
为了便于读者理解,笔者手写了计算过程中涉及到的变量的值。
细心的读者可能注意到,代码中在带上位置信息时,还有以下2行。
这里涉及到另一块理论知识-ntk理论:ntk相关的理论知识笔者也没有看太懂,不过提出ntk的作者在reddit的论坛中做了一个比喻来帮助读者理解ntk在做的事情。
rope旋转就像时钟一样,通常情况下一个时钟只能测量126060s=43.2ks,如果把秒针的旋转速度放慢4倍,那么就拥有了一个可以测量443.2k=172.8ks的时钟,但是这时你很看清1s的流逝,因为要经过4s,秒针才会移动一个刻度。更好的做法是将秒针的旋转速度保持不变,将分针的旋转速度放慢1.5倍,将时针的旋转速度放慢2倍,这时就拥有了一个可以测量1.52*43.2k=129.6k 秒时间流逝的时钟,同时还可以精确的测量时间的流逝。
其中秒针属于高频旋转的,尽量保持原样,对于低频旋转的分针、时针,可以多缩放一些,来获取总量的提升。
所以这里的核心思想就是:高频保持原样,低频缩放。
在llama3.1的ntk缩放中,对于freqs的缩放策略有3种:
- 高频的不变
- 低频的固定缩小8倍
- 介于高频和低频之间的频率则进行平滑处理。
明确了precompute_freq_cis是可能被ntk缩放过的mtheta后【会被传到apply_rotary_emb函数的freqs_cis变量】,再来看下apply_rotary_emb的实现。
xq首先会进行维度的调整,由(batch_size,seq_len,num_of_local_heads,head_dim)调整为(batch_size,seq_len,num_of_local_heads,head_dim/2,2)。
torch.view_as_complex则会将最后一个维度的两个分量分别视为复数的实部和虚部。
而freqs_cis的维度会由(seq_len,head_dim/2)调整为(1,seq_len,1,head_dim/2),因为后续要给xq带上位置信息,因此需要将位置信息的维度和xq对齐。
计算详情同苏剑林给出的公式相符,过程中计算涉及到的变量示意如下。
经过flattern(3)后,表示从第三个维度开始聚合维度,第三个维度为head_dim/2,第4个维度为2,经过flattern后,第4个维度消失,第三个维度变为head_dim。
至此,已按照token所在的位置将旋转编码过的位置信息加了上去。
接下来就是计算masked self attention了,其中关于 key、value的移动、缓存、拷贝,逻辑比较简单,这里主要围绕masked self attention进行一个解释。
先看一下下面这个问题:
假设我们有一个维度为(5,3)的矩阵,
其值如下
我们的目标是计算一个新的矩阵,新的矩阵中每一行的分量为前几列分量的平均值
即:
1. 目标矩阵第一行的第i个分量对应于前一行对应分量和的平均值
2. 目标矩阵第二行的第i个分量对应于前2行对应分量和的平均值
3. 目标矩阵第三行的第i个分量对应于前3行对应分量和的平均值
可以通过下述矩阵相乘的方式,来实现上述目标
可以看到,我们构造了一个b矩阵,第一行的值只有第一个元素为1,第二行的值前2个元素分别为0.5,其余元素为0,以此类推。
b矩阵也可以通过下述方式获得。
首先通过torch.full初始化一个5*5的矩阵,所有的元素都是-inf对应浮点值,然后通过triu只保留右对角线上侧的这些元素的值,其余的元素被改成0,
在计算softmax时,由于e的负无穷趋近于0,所以softmax后,-inf分量归一化后的值为0。
然后沿着最后一个维度计算一下softmax,就得到了前面我们手动初始化的b矩阵。
如果把上面的a矩阵视为我们最终的value向量,其中每行代表每个token的向量化表示,b矩阵视为query和key的关联矩阵【一个退化的、特殊的、求平均值的关联关系矩阵】,那么b*a就是在对a进行masked self attention计算。
通常情况下,b矩阵不是均值矩阵,因为不同的query和key的关联程度不同。
通过下图对关联关系矩阵进行一个说明,记维度为(batch_size,seqlen,hidden_size)的输入矩阵,经过query、key、value激发矩阵作用后产出了下面的分量,query和key的关联矩阵中的元素 query_i * key_j就代表了,第i个位置的token与第j个位置的token的内积、关联关系、相关程度。
根据query、key、value是否由同一个x经不同的激发矩阵激发出,以及是否使用mask矩阵对关联矩阵进行mask操作,可以将attention分为3种,这3种attention在 attention is all you need论文中给出的图片里都存在。
- cross attention:query分量、key分量、value分量不是经同一个x激发产生的,允许位置为i的query分量看见位于位置i之后的key分量【不使用mask矩阵】,这样的关联矩阵直接进行softmax后,同value矩阵相乘得到的attention,属于cross attention。
- self attention:query分量、key分量、value分量由同一个x激发产生的,允许位置为i的query分量看见位于位置i之后的key分量【不使用mask矩阵】,这样的关联矩阵直接进行softmax后,同value矩阵相乘得到的attention,属于self attention。
- masked self attention:query分量、key分量、value分量由同一个x激发产生的,不允许位置为i的query分量看见位于位置i之后的key分量【通过mask矩阵实现】,对mask过后的关联矩阵进行softmax后,同value矩阵相乘得到的attention就属于masked self attention。
当然transformer最关心的其实是上述关联矩阵的最后一行对value进行的加权计算,它用最后一个词所激发的query和句子中所有其他词激发的key相乘得到一个关联关系,然后用这个关联关系来加权value。
有了上述基础后,就不难理解llama3.1 model.py中的注意力计算机制了。
传入的mask如下,除了上面提到的2个操作外,多了一个垂直方向堆0的操作
堆1个0的示意结果如下
笔者理解,这么计算的mask矩阵,就可以实现对seqlen窗口内的元素按masked self attention计算关联关系矩阵,对seqlen窗口之前的元素按self attention计算关联关系矩阵。
现在就可以看一下attention中剩余的注意力计算代码了,首先对输入的张量xq进行了维度调整,将number_of_local_heads所在的维度移到第1维度,keys操作相同,这样就实现了query和keys前两个维度对齐,矩阵乘法只发生在后2个维度,后2个维度最终会计算得到一个大小为(seqlen,cache_len+seqlen)的关联关系矩阵。
至此,整个llama3.1 model.py中最难的内容已经介绍完毕,
Attention模块还剩下的内容就是将加权计算后的value矩阵,调整成(batch_size,seqlen,number_of_local_headshead_dim)的维度,然后再乘上一个大小为(number_of_headshead_dim,dim)的矩阵,得到Attention模块的最终输出结果,同输入大小一致,都为(batch_size,seqlen,dim)的一个张量。
FeedForward
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
self.w2 = RowParallelLinear(
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
)
self.w3 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
整个前反馈层,只定义了3个线性层
其映射逻辑也很简单,
对于输入的大小为(batch_size、seq_len、dim)的输入向量,首先乘上w1矩阵,结果变为(batch_size,seq_len,hidden_dim)维度,然后通过silu函数激活一下,激活函数不改变矩阵的维度,同时将输入向量(batch_size、seq_len、dim)乘上w3矩阵,结果变为(batch_size,seq_len,hidden_dim)维度,两个维度为(batch_size,seq_len,hidden_dim)的向量按位相乘,再乘上w2矩阵,结果又恢复成输入的大小(batch_size、seq_len、dim)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
为什么要使用激活函数?
神经网络中每一层的输入输出都是一个线性求和的过程,下一层的输出只是承接了上一层输入函数的线性变换,所以如果没有激活函数,那么无论你构造的神经网络多么复杂,有多少层,最后的输出都是输入的线性组合,纯粹的线性组合并不能够解决更为复杂的问题。而引入激活函数之后,我们会发现常见的激活函数都是非线性的,因此也会给神经元引入非线性元素,使得神经网络可以逼近其他的任何非线性函数,这样可以使得神经网络应用到更多非线性模型中。
silu激活函数又叫Swish激活函数,其函数曲线如下,对于不同的beta,函数曲线有所不同。
TransformerBlock
有了之前的3个模块 RMSNorm、Attention、FeedWard模块,一个单独的transformer block构建起来就很容易了。
在初始化函数中初始化了1个attention模块,1个feedforward模块,2个归一化模块
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
对于输入的大小为(batch_size、seq_len、dim)维的向量,其计算机制就2行
- 首先归一化一下,然后计算注意力,然后和自身相加,做一个残差链接得到h
- 将h归一化一下,送给feedforward处理后和h做一个残差连接。
就得到了单个transformer block块的输出。
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out
为什么要做残差连接?
解决梯度消失和网络退化的问题。
如果使用了残差链接,在反向传播计算loss的梯度时,梯度中有一项直接来自于loss对x的偏导数。如果不使用残差链接就没有这项,就只能通过中间权重矩阵传播回去,只通过中间权重矩阵传播回去的值可能非常小【梯度消失问题】。
通常网络的层数不一定是最优的层数,那么较深的一些层可能就属于冗余层,对于冗余层,如果没有残差链接,next(x)=h(x),我们希望实现的映射关系是h(x)=x【next(x)=x】,如果使用了残差链接,next(x) = Relu(x+h(x)),我们希望实现的映射关系是h(x)=0【next(x)=x】,因此对应的权重矩阵参数只需要偏向于0即可实现h(x)=0,训练起来收敛更快。
Transformer
整个Transformer模型是由
- 一个embedding【理解成输入固定为词表大小的线性层】
- 若干层transformerblock
- 一个归一化模块
- 一个将输出映射到词表大小的线性层。
在初始化transformer的时候,也预先计算了rope中用到的角度的复数形式。
class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.tok_embeddings = VocabParallelEmbedding(
params.vocab_size, params.dim, init_method=lambda x: x
)
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = ColumnParallelLinear(
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
)
self.freqs_cis = precompute_freqs_cis(
params.dim // params.n_heads,
params.max_seq_len * 2,
params.rope_theta,
params.use_scaled_rope,
)
对于输入的处理逻辑如下:
- 首先对token进行embedding,得到token的向量表示。
- 根据传入的start位置,从初始好的rope复数角度变量中取出seqlen长度的复数角度,为后续在 attention 的模块中给query、keys带上旋转位置编码做准备。
- 初始化一个mask矩阵,为Attention模块中计算masked self attention 做准备。
- 将输入送给第一个transformer block块开始处理,第一块处理完后,输出结果送到第二块,重复直到所有的块都完成处理。
- 对所有transformer block处理完后的最终结果进行归一化。
- 将归一化的结果映射到词表大小,以便后续decode成可读的字符串。
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)
# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack(
[torch.zeros((seqlen, start_pos), device=tokens.device), mask]
).type_as(h)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h).float()
return output
参考资料
- 知乎关于normalization的介绍:zhuanlan.zhihu.com/p/647813604
- batch normalization和layer normalization的代码化解释:afterhoursresearch.hashnode.dev/batch-norma…
- 旋转编码介绍:kexue.fm/archives/82…
- ntk awared rope:www.reddit.com/r/LocalLLaM…
- 深入理解self attention:jalammar.github.io/illustrated…
- karpathy系列——强烈推荐!!!:www.youtube.com/watch?v=kCc…
- 如何理解激活函数:zhuanlan.zhihu.com/p/364620596
- 如何理解残差连接:zhuanlan.zhihu.com/p/449792026