声明,本文主要参考:
minimind官方代码库:
github.com/jingyaogong…
以及b站up主:算法魔法师
www.bilibili.com/video/BV1N9…
以及豆包、千问两款APP进行知识问答补全。
BBPE
强烈建议先观看算法魔法师的视频,了解编码知识(BPE和BBPE分词)
揭秘分词底层技术“BPE”-大部分大模型都在用的分词算法!_哔哩哔哩_bilibili
minimind使用的是BBPE
一句话总结:按照utf-8进行分词,按照字节进行分层,解决中英文的问题
这也解释了,为什么词表为什么会有一堆乱码,起初还以为是我pycharm设置的问题~
minimind使用的是
PreTrainedTokenizerFast
其中比较有意思的
tokenizer.apply_chat_template,解答了,如何将用户常见的 messages输出变成大模型能够理解的格式:
你只需要传入
messages tokenizer.apply_chat_template(messages, tokenize=False) # 内部自动调用 chat_template 拼好格式!
运行结果,因此可以看到,这个其实类似于一个桥接器,将用户的输入转化成一个大模型可分词的
为什么tokenizer.json里面的此表是乱码?因为是使用了BBPE分词,这个不是显示的BUG~
只要直到,tokenizer会将输入进行分词并且将分词后的结果,输出为 id序列即可
Attention
Attention涉及很多细节,我在下文进行补充,可以看左侧目录选择不懂的知识点进行学习
以上红框属于attention的部分
从源码来看,实际上经历了以下的步骤
[输入]
↓
Q投影 → K投影 → V投影:xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
↓
分头塑形 xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
↓
Q/K 归一化 xq, xk = self.q_norm(xq), self.k_norm(xk)
↓
RoPE注入位置信息 xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
↓
KV缓存拼接 xk = torch.cat([past_key_value[0], xk], dim=1)
↓
KV重复匹配Q头 xk = torch.cat([past_key_value[0], xk], dim=1)
↓
注意力打分+掩码 output = self.attn_dropout(F.softmax(scores.float(), dim=-1).type_as(xq)) @ xv
↓
多头拼接融合 output.transpose(1, 2).reshape(bsz, seq_len, -1)
↓
输出线性变换 self.resid_dropout(self.o_proj(output))
↓
[最终输出+缓存] return output, past_kv
为什么v 不进行归一化?
这里就要理解 attention的原理
核心原因(一句话)
注意力公式里,只有 Q 和 K 会做「点积打分」,V 从来不参与打分! 所以 只有 Q、K 需要归一化稳定数值,V 完全不需要。
看注意力公式就懂了
plaintext
注意力分数 = Softmax( Q · K^T / √d ) · V
Q 和 K → 做点积 → 数值大小会直接影响分数→ 必须归一化,不然分数会爆炸 / 消失
V → 只是被加权求和→ 数值大小不影响分数分布→ 不需要归一化
大白话解释
Q:查询
K:键
V:值
打分过程:Q 去匹配 K → 打出分数 → 用分数去加权 V
Q、K 是 “评委” → 评委打分必须公平、尺度统一 → 必须归一化
V 是 “选手” → 选手本身不需要归一化 → 不用管
attention公式
详细原理可以看UP主“算法魔法师”视频:
关于attention,一定会有很多的疑问:
1、为什么要对kv进行复制?(GQA)
2、什么是旋转位置编码(RoPE,苏剑林大神提出)
3、什么是past_key_value(kv cache)
4、怎么做mask掩码
这些都是之前在原版attention不一样的地方,详细原理见下文
RMSNorm
一种归一化的手段。
参考豆包:
RMSNorm = 均方根归一化,极简版:只做缩放,不学偏移,比 LayerNorm 更快、更简单。
和 LayerNorm 最大区别
- LayerNorm:归一化 + 缩放 weight) + 平移 (β)
- RMSNorm:只归一化 + 缩放 (weight),删掉平移 β
因为没有做平移,只做了归一化 + 缩放,实际上保持了分布的均值
计算更快去掉 LayerNorm 的均值平移,少算一堆运算,训练推理提速。
- 效果不输层归一化大模型场景下,删掉偏置 β 不影响精度,LLaMA 全系都弃 LN 用它。
- 训练更稳、梯度更顺滑只做尺度缩放,不偏移数据分布,缓解梯度消失。
- 适配注意力更友好你代码里只用在Q、K上:统一 QK 向量幅值,注意力打分更均衡,不容易出现极端权重。
和 LayerNorm 对比
-
LayerNorm:减均值 + 除方差 + 缩放 + 偏移
-
RMSNorm:只除均方根 + 缩放少两步,更快、更轻、大模型最优解。
总之:减少参数规模、并且减少复杂的非线性变换
pytorch官方自带了:torch.nn.RMSNorm(dim, eps=eps)
PyTorch ≥ 2.1 正式内置nn.RMSNorm大佬的源码里是自己手搓:
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return (self.weight * self.norm(x.float())).type_as(x)
可以看出,缩放是可学习的参数 self.weight = nn.Parameter
GQA
参考:
www.bilibili.com/video/BV1xV…
简单理解: MHA是原始的多头
MQA:解决MHA的KV cache来回搬运问题,多个Q对应一个KV
GQA:MQA就一个kv,表达力不足,搞成分组。
代码里的原理是:
k、v映射可学习的参数本身就少,例如8个Q,分两组,这里的k v 的只有4
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
xq, xk, xv = (xq.transpose(1, 2), repeat_kv(xk, self.n_rep).transpose(1, 2), repeat_kv(xv, self.n_rep).transpose(1, 2))
这里的repeat就是重复而已
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
bs, slen, num_key_value_heads, head_dim = x.shape
if n_rep == 1: return x
return x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
本质上,实际上就是削减了k、v的可学习的参数量。
为什么这里的映射,不加bias? ( bias=False)
nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
结合你提供的代码和 RMSNorm 的原理,Q、K、V 投影层不加 bias 的原因,主要可以归结为以下三点:
- 代码中的
q_norm和k_norm是 Qwen3 等现代模型的标配 你提供的代码中,在 Q 和 K 投影后显式加入了RMSNorm,这其实是 Qwen3 架构相对于 Qwen2.5 的一个重要升级(称为 QK 预归一化)。
- Qwen3 的做法:不仅 Q 和 K 后面加了
RMSNorm,而且 QKV 投影层的bias也是通过配置项config.attention_bias来灵活控制的(默认通常设为False)。 - Qwen2.5 的做法:Qwen2.5 的 QKV 投影层是直接把
bias=True写死在代码里的,并且没有 QK 预归一化层。 你的这段代码明显采用了更先进的 Qwen3 架构思路。在这种架构下,因为已经有了专门的归一化层来稳定 Q 和 K 的数值分布,投影层的偏置就显得更加多余了。
- 工程上的极致优化:能省则省 在大语言模型(LLM)中,参数量动辄几十亿甚至上千亿。虽然单个偏置向量(大小等于隐藏层维度)看起来不大,但每一层、每一个注意力头累积起来,去掉这些冗余的
bias也能:
- 减少显存占用:模型权重文件更小,加载和部署更省资源。
- 提升推理速度:矩阵乘法(GEMM)后少了一次向量加法运算。在大规模并发推理时,这种微小的优化能带来可观的吞吐量提升。
flash_attn
具体的原理参考豆老师的讲解,实际上使用只需要调用官方API
把 Q/K/V 切成小块(tile),只在高速片上内存(SRAM)里算小块注意力,不存完整 N×N 矩阵,用在线 softmax 渐进合并结果。 NVIDIA
两大核心技术
(1)分块计算(Tiling)NVIDIA
把 Q, K, V 沿序列维度切成小方块(刚好放进 SRAM)
外层循环:加载 K, V 块到 SRAM
内层循环:加载 Q 块,在 SRAM 内算
全程不写 N×N 矩阵到 HBM
(2)在线安全 softmax(Online Softmax)
softmax 依赖整行最大值与指数和,Flash Attention 用增量方式:
每块记录:行最大值 m、指数和 l
块间合并:重新缩放、更新最大值与和
结果完全精确、无精度损失
(3)反向:重计算(Recomputation)
不存巨大的 S/P 矩阵
反向时重新计算小块 S/P → 用 “多算一点” 换 “省巨量显存”
三、效果(一句话总结)
显存:从 O (N²) → O(N) (序列长 32K 从 4GB → 几十 MB)
速度:2–4 倍加速NVIDIA
精度:完全精确(不是近似)
支持更长上下文:GPT-4 128K、LLaMA 3 128K 都靠它
在代码层面
if self.flash and (seq_len > 1) and (not self.is_causal or past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=self.is_causal) # 1,8, 22, 96
这两段代码的效果是等价的
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) # 初始:1,8,22,22,后面就是1,8,1,23
if self.is_causal: scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1)
if attention_mask is not None: scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9
output = self.attn_dropout(F.softmax(scores.float(), dim=-1).type_as(xq)) @ xv # 1,8,22,96
mask掩码
mask矩阵的本质是遮掩信息,不让模型看到未来的信息和无效的padding信息。
attention公式
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) 实现的是
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores 是什么?
形状:(batch, n_heads, seq_len, seq_len)这是Q × K 得到的注意力分数表。
if self.is_causal: scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1)
if attention_mask is not None: scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9
其中 因果编码是
torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1)
生成了一个类似的矩阵
[ [0, -inf, -inf],
[0, 0, -inf],
[0, 0, 0] ]
scores[:, :, :, -seq_len:] += 这个矩阵,将未来的信息,给置为负无穷。
后面softmax,将负无穷压成0。
其中mask矩阵,则是有效的为1,无效的为0;
scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9
可以看到,也是同样的原理,将score 遮掩的部分,置为了负无穷。
后面softmax,将负无穷压成0。
KV cache
参考,非常建议观看,讲解的很清楚
www.bilibili.com/video/BV1dU…
本轮 吃的 QKV + 历史的 KV向量,足以计算本轮的输出
在模型进行梯度传播 之前,历史的KV结果都是固定的。
阅读源码时,这里要区分两种场景:1、初始输入场景 2、后续每个字输入场景
1、初始输入场景,初始kv cache,也是首token输出的场景(补充一句,难怪很多网站说 首token输出是 kv cache初始化的过程 ):
xq, xk, xv 的长度都是 bs,sel_len,head,head_dim
。。。。。。
past_kv = (xk, xv) if use_cache else None # 1,seq,headcnt,headdim
。。。。。。
return output, past_kv
外层循环,总共有8层,那么每层的kv cache都存起来了,存到presents
presents = []
for layer, past_key_value in zip(self.layers, past_key_values):
hidden_states, present = layer(
hidden_states,
position_embeddings,
past_key_value=past_key_value,
use_cache=use_cache,
attention_mask=attention_mask
)
presents.append(present) # 这里存放kv cache,这样列表永远是8、8代表的attention的层数,里面的k、v是 1,x,4,96,x会不断增长
。。。。
return hidden_states, presents, aux_loss
补充:input_ids = torch.cat([input_ids, next_token], dim=-1) 最后将本轮输出的首token,变成输入
2、首字符
最外层的generate 函数里,
这里的past_len 查看attention第一层的k的长度,也是历史长度
past_len = past_key_values[0][0].shape[1] if past_key_values else 0
# 注意这里的 input_ids[:, past_len:],实际上就是将首token取出了
outputs = self.forward(input_ids[:, past_len:], attention_mask, past_key_values, use_cache=use_cache, **kwargs)
最底层的attention模块,将 当前的kv和历史拼接,长度不断变长(bs,oldseq+1,headcnt,headdim)
if past_key_value is not None:
xk = torch.cat([past_key_value[0], xk], dim=1)
xv = torch.cat([past_key_value[1], xv], dim=1)
外层的presents,记录了 新的 (bs,oldseq+1,headcnt,headdim)
presents = []
for layer, past_key_value in zip(self.layers, past_key_values):
hidden_states, present = layer(
hidden_states,
position_embeddings,
past_key_value=past_key_value,
use_cache=use_cache,
attention_mask=attention_mask
)
presents.append(present) # 这里存放kv cache,这样列表永远是8、8代表的attention的层数,里面的k、v是 1,x,4,96,x会不断增长
3、后续字符就是同样的道理了
RoPE
参考:www.bilibili.com/video/BV1Fj…
强烈建议看一下up主的视频。
原本的正余弦位置编码是加法,使用旋转矩阵,将第m个词旋转m个角度,第n个词旋转n个角度。
旋转矩阵长这个样子
第m个词和第n个词之间的相对位置是 n -m 个角度
RoPE = 把 768 维向量两两分组 → 384 个 2D 向量
每个 2D 向量按位置旋转不同角度
旋转完拼回 768 维
位置信息就自然融进语义里了!
假设是 bs、seq-len,768维 的输入,
768维度,两两一组,总共384组
不同位置 → 角度不同(核心)
位置 0 → 所有 384 组都旋转 0° 简化版,实际上,每一组的旋转度数也都不一样
位置 1 → 所有 384 组都旋转 5° 简化版,实际上,每一组的旋转度数也都不一样
位置 2 → 所有 384 组都旋转 10° 简化版,实际上,每一组的旋转度数也都不一样
位置 3 → 所有 384 组都旋转 15° 简化版,实际上,每一组的旋转度数也都不一样
这就是:
固定一个位置,旋转全部 的 384 对,每一对的旋转角度略微不同,详情看下面的源码
每组应该旋转多少角度的公式,注意这里的 i 代表组
数学公式(就是 RoPE 标准公式)
��=1����2����
d:单头维度(你这里是 96,不是全局 768!重点)
i:维度组下标(0,1,2… 两两为一组)
10000 是基础底数
阅读一下源码
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: dict = None):
freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
if rope_scaling is not None: # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
orig_max, factor, beta_fast, beta_slow, attn_factor = (
rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
)
if end / orig_max > 1.0:
inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
freqs = freqs * (1 - ramp + ramp / factor)
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
return freqs_cos, freqs_sin
首先这个就实现了类似的公式
freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
只不过这里源码的rope_base不是10000,默认是100w
注意:这里的 freqs 是长度为48的向量,代表每一组的转速不一样。并非上面的例子里,每一个位置的,所有组共用一个旋转角度。
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
实现了计算好,每个位置,旋转多少度。
形状变化: t : [32768] freqs: [48] 外积后 → [32768, 48]
这里的下标,实际上是 序列位置 + 组的位置
接下来,把每一组的旋转角度计算好。
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
变成了 [32768, 96],形成最终查询的大表(以空间换时间)。
注意到没,这里的是拼接,也就是变成了 [c0, c1, c2, ..., c47, c0, c1, c2, ..., c47]的顺序
而不是[c0,c0, c1,c1, c2,c2, ..., c47,c47]
我们希望的是[c0,c0, c1,c1, c2,c2, ..., c47,c47],因为前面的维度每两个为一组,旋转的角度Ci 应该是一样的。
那么这里是有bug嘛?不是的,这里又涉及到一个关键的代码
这里实际上是
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
def rotate_half(x): return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
q_embed = ((q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))).to(q.dtype)
k_embed = ((k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))).to(k.dtype)
return q_embed, k_embed
实际上,我们上面理解的,两两相邻作为一组,也是错的
实际上,是第0个位置 + 第48个位置,凑成一组,间隔48。并非两两相邻一组。
引用豆老师的说法,
单头维度 d=96
预计算出来基础频率角:θ0,θ1,...,θ47(共 48 个)
执行 torch.cat([cosθ, cosθ], dim=-1)得到总长 96 的数组:
�=[cosθ0,cosθ1,...,cosθ47,cosθ0,cosθ1,...,cosθ47$$$$�=[sinθ0,sinθ1,...,sinθ47,sinθ0,sinθ1,...,sinθ47
def rotate_half(x):return torch.cat((-x[..., 48:], x[..., :48]), dim=-1)
设向量
前 48 维: �0∼�47
后 48 维: �0∼�47
执行后: ������_ℎ���(�)=[−�0,−�1,...,−�47, �0,�1,...,�47
旋转公式
����=�⋅�+������_ℎ���(�)⋅�
拆开成对看,只看第一组一对维度:原向量一对:对应位置余弦正弦:\cosθ_0,\ \sinθ_
代入计算:
�0′=�0⋅cosθ0−�0⋅sinθ0 �0′=�0⋅cosθ0+�0⋅sinθ0
完美就是标准二维旋转公式!
为什么拼接 [cos,cos] 刚好能用
拼接后前 48 个cos对应每组第一个元素
拼接后后 48 个cos对应每组第二个元素
rotate_half 天然把每一对维度拆分调换正负
相乘相加后,自动完成 一组维度共用同一个 θ
YaRN
当理解了rope算法之后,上面有一小段代码,包含了yarn算法的内容。
if rope_scaling is not None: # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
orig_max, factor, beta_fast, beta_slow, attn_factor = (
rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
)
if end / orig_max > 1.0:
inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
freqs = freqs * (1 - ramp + ramp / factor)
YaRN 的天才想法:
高速维度(近处)几乎不缩,慢速维度(远处)使劲缩
中间维度 → 平滑过渡缩
远近都保住!
说白了,就是插值
引用豆老师的讲解:
固定参数
orig_max = 2048:模型原生支持长度
factor = 16:扩展 16 倍
beta_fast = 32:高频阈值
beta_slow = 1:低频阈值
关键代码:
inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
作用:
根据阈值 b,计算 “从第几号频率组开始缩放”
我们代入算一遍:
① 计算 inv_dim (beta_fast=32)
b = 32
分子 = dim * log(orig_max/(b*2π))
分母 = 2 * log(1e6)
结果 ≈ 17.89 → 向下取整 → **17**
② 计算 inv_dim (beta_slow=1.0)
b = 1.0
结果 ≈ 24.57 → 向上取整 → **25**
low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
结果:
- low = 17
- high = 25
含义:
- 0 ~ 17 组:不缩放(高频,近处)
- 17 ~ 25 组:平滑过渡
- 25 ~ 47 组:完全缩放(低频,远处)
生成平滑过渡 ramp
ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001),0, 1)
torch.arange(48) → [0,1,2,...,47]
计算:
ramp[i] = (i - 17) / (25 - 17) = (i-17)/8
最终 ramp 结果:
- i=0~17 → 0
- i=17~25 → 0 → 1 平滑上升
- i=25~47 → 1
YaRN 核心公式:缩放频率
freqs = freqs * (1 - ramp + ramp / factor)
代入 factor=16:
scale = (1 - ramp) + ramp / 16
最终:
-
ramp=0 → scale=1 → 不缩放
-
ramp=1 → scale=1/16 → 缩 16 倍
-
中间 → 平滑缩放 慢慢开始缩放
最终效果(最关键)
频率组 0~17:不缩放(高速,近处)
频率组 17~25:慢慢开始缩放
频率组 25~47:直接缩小 16 倍(低速,远处)
FeedForward
前馈层,和普通的线性层还不一样
class FeedForward(nn.Module):
def __init__(self, config: MiniMindConfig, intermediate_size: int = None):
super().__init__()
intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
翻译成人话:
x输入:[batch, seq_len, 768]- 经过
gate_proj→ 高维向量 - 经过激活函数 → 变成门控信号
- 门控信号 × up_proj 升维向量→ 相当于 “筛选有用信息”
- 最后
down_proj降维回 768
这就是 SwiGLU / GateFFN,现在大模型标配。
MOEFeedForward
这个会比较麻烦一些。
class MOEFeedForward(nn.Module):
def __init__(self, config: MiniMindConfig):
super().__init__()
self.config = config
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) # # 路由门
self.experts = nn.ModuleList([FeedForward(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)])
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
batch_size, seq_len, hidden_dim = x.shape
x_flat = x.view(-1, hidden_dim) # # [B*S, 768] # 2. 展平所有Token,方便统一路由
scores = F.softmax(self.gate(x_flat), dim=-1) # [B*S, num_experts]
topk_weight, topk_idx = torch.topk(scores, k=self.config.num_experts_per_tok, dim=-1, sorted=False) # [B*S, num_experts]
if self.config.norm_topk_prob: topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20) # 权重归一化
y = torch.zeros_like(x_flat) # [B*S, 768]
for i, expert in enumerate(self.experts):
mask = (topk_idx == i) # # [B*S, num_experts]
if mask.any():
token_idx = mask.any(dim=-1).nonzero().flatten() # [N个字符选中了,这N个字符在B * S里的index]
weight = topk_weight[mask].view(-1, 1) # 将对应的权重取出,权重的维度是 [N个字符选中了,1]
y.index_add_(0, token_idx, (expert(x_flat[token_idx]) * weight).to(y.dtype))
elif self.training:
y[0, 0] += 0 * sum(p.sum() for p in expert.parameters()) # 为了让 PyTorch 计算图知道:这个专家被使用了!如果一个专家完全没被用到,它的参数就不会出现在计算图里,多轮之后,优化器会报错:
if self.training and self.config.router_aux_loss_coef > 0:
load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
else:
self.aux_loss = scores.new_zeros(1).squeeze()
return y.view(batch_size, seq_len, hidden_dim)
其中比较的地方有:门,选择哪个专家
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
专家们:
self.experts = nn.ModuleList([FeedForward(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)])
1、forward时,
会将输入全部展平:
x_flat = x.view(-1, hidden_dim) # # [BS, 768] # 2. 展平所有Token,方便统一路由*
2、然后计算门分数
scores = F.softmax(self.gate(x_flat), dim=-1) # [BS, num_experts]*
3、然后把需要的TOP专家取出
topk_weight, topk_idx = torch.topk(scores, k=self.config.num_experts_per_tok, dim=-1, sorted=False) # [B*S, num_experts]
4、初始化最终结果
y = torch.zeros_like(x_flat) # [BS, 768]*
5、for循环遍历所有的专家:for i, expert in enumerate(self.experts):
6、找到用到了第i个专家的token
mask = (topk_idx == i) # # [BS, num_experts]*
token_idx = mask.any(dim=-1).nonzero().flatten() # [N个字符选中了,这N个字符在B * S里的index]
注意,这里的向量长度,不是B*S全部了,而是N,N 是选中第i个专家的token
7、并且将对应的权重也捞出来
weight = topk_weight[mask].view(-1, 1) # 将对应的权重取出,权重的维度是 [N个字符选中了,1]
8、使用对应的专家计算的结果,放到目标里。
y.index_add_(0, token_idx, (expert(x_flat[token_idx]) * weight).to(y.dtype))
这段代码是精髓
expert(x_flat[token_idx]) * weight
解释
[B*S, 768] 经过tokenidx 变成了 【N,768】,经过expert专家映射后变成了 N * 768,然后 与weight权重【N,1】点乘,还是N * 768。然后根据index_add_的 token_idx索引,将会这N行向量,加入到对应的y的位置上
截至到上面的步骤,讲完了多专家时怎么预测,但是为了在训练的时候,能够让所有的专家尽可能都被训练到,因此增加了一个
辅助损失函数self.aux_loss
if self.training and self.config.router_aux_loss_coef > 0:
load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
这里很有趣
topk_idx 的大小是 【b*s, num_experts】,例如
topk_idx [6,2] 是这样:
[
[1, 3], # token 0
[0, 1], # token 1
[2, 3], # token 2
[1, 0], # token 3
[3, 2], # token 4
[0, 1], # token 5
]
F.one_hot(topk_idx, self.config.num_experts)
得到的结果 shape = [6, 2, 4]
我直接把它画出来给你看:
# 维度含义:[第几个token] [第几个选中专家] [4个专家的one-hot]
# -----------------
# token 0: 选了 [1,3]
# -----------------
[
[0, 1, 0, 0], # 选专家1 → one-hot
[0, 0, 0, 1] # 选专家3 → one-hot
],
.mean(0) 变成了 [ 2, 4]
内容大概是:
[
[0.33, 0.50, 0.16, 0.16], # 第1个选择位的平均
[0.16, 0.33, 0.16, 0.50] # 第2个选择位的平均
]
scores.mean(0) 的shape 是 【4】
load是【2,4】
这里可能有一个bug
load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
会变成一个2*4的向量,总和为2。
我不确定这里会不会是个bug?应该再除以 self.config.num_experts。这样全局之和为1?
self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
这里是损失函数,是希望load的真实分布和 门输出的score的分布尽可能不同。
如果两个完全一致,则 aux_loss 越大。
两个完全不一致,则 aux_loss 越小。
这个辅助损失函数是为了鼓励越小越好。即门的输出和真实使用的分布,越不同越好。
乘以 self.config.num_experts 是为了均一化损失函数。
比如
(1/8 * 1/8) * 8 = 1/8 = 0.125
专家越多 → 这个值越小
专家 = 16 → 0.0625
专家 = 32 → 0.03125
结果:
aux_loss 会变得超级小,几乎等于 0!
乘以个 self.config.num_experts ,就会让损失函数变成1
温度、重复性惩罚、top_k、top_p、随机化采样
这些也是经常容易混淆的地方,也是我们模型设置的常用的参数
for _ in range(max_new_tokens):
past_len = past_key_values[0][0].shape[1] if past_key_values else 0 # past_keY_values 是层数的列表,比如8.然后每个元素是tuple,(k,v),其中k的是1,seq,4, 96 ,其中每次循环seq会+1,这样下下面的输入永远是最新的
outputs = self.forward(input_ids[:, past_len:], attention_mask, past_key_values, use_cache=use_cache, **kwargs)
attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[0], 1)], -1) if attention_mask is not None else None
logits = outputs.logits[:, -1, :] / temperature
if repetition_penalty != 1.0:
for i in range(input_ids.shape[0]):
seen = torch.unique(input_ids[i])
score = logits[i, seen]
logits[i, seen] = torch.where(score > 0, score / repetition_penalty, score * repetition_penalty)
if top_k > 0:
logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float('inf')
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
mask = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
mask[..., 1:], mask[..., 0] = mask[..., :-1].clone(), 0
logits[mask.scatter(1, sorted_indices, mask)] = -float('inf')
温度
其中温度 temperature 可以看到,是
logits = outputs.logits[:, -1, :] / temperature
temperature越大,比如2,则logits都会压缩变小,原本的各个词之间的差距,变小,输出就更随意
temperature越大,比如0.5,则logits就会变的很大,差距变大,输出就会变的固定(选择概率最大的)
重复性惩罚
for i in range(input_ids.shape[0]):
seen = torch.unique(input_ids[i])
score = logits[i, seen]
logits[i, seen] = torch.where(score > 0, score / repetition_penalty, score * repetition_penalty)
seen = torch.unique(input_ids[i]) 获得已经看过的token id。
注意,这里的seen 的内容,是 id,最大时词表的大小。代表词表里的对应词的id
score = logits[i, seen] 将这些已经看过的词的预估值拿出来
repetition_penalty越大,则越惩罚重复,即鼓励不重复
举个例子:repetition_penalty 假设 为 2
logits[i, seen] = torch.where(score > 0, score / repetition_penalty, score * repetition_penalty)
则正的数值,削弱2倍,变小,负数 * 2,变得更小。 变小就意味着不会被选中了
top_k
硬性只要最大的K个:
logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float('inf')
torch.topk(logits, top_k)[0] 返回topk的数值,数值是 由大到小
[..., -1, None]] 返回每一行的最后一个值
最后,将小于这个值的,都赋值为负无穷,不要被选中
top_p
if top_p < 1.0:# 1. 把 logits 从大到小排序
sorted_logits, sorted_indices = torch.sort(logits, descending=True)# 2. 转成概率,从大到小累加,超过 p 的扔掉
mask = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
# 3. 关键小修正:保证至少保留一个词
mask[..., 1:], mask[..., 0] = mask[..., :-1].clone(), 0
# 4. 把掩码还原回原来顺序,禁用低概率词
logits[mask.scatter(1, sorted_indices, mask)] = -float('inf')
简单讲,先排序,再累计求和,右移一位(边界处理,右移才能确保一定大于等于),然后还原到对应的遮掩位置,将其置为 负无穷
例如第一个元素是0.6,第二个是0.24。阈值0.8
则只有第一个位置是满足的,其实是第一个和第二个元素满足,要右移一位。
scatter一般用于排序后的还原,将mask的元素,通过 indices 还原到对应的位置上
随机化采样
next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1) if do_sample else torch.argmax(logits, dim=-1, keepdim=True)
如果开启随机化采样
按照softmax结果的概率值进行采样,返回对应元素的下标
torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
否则
只取最大的
torch.argmax(logits, dim=-1, keepdim=True)
最后给你一个超级直观对比
表格
| 方法 | 含义 | 优点 | 常用值 |
|---|---|---|---|
| Top-K | 保留前 K 个概率最高词 | 简单 | 20~50 |
| Top-P | 保留概率总和≥P 的词 | 智能、灵活 | 0.9 |
| Temperature | 控制随机性 | 创造力 | 0.7 |
| Repetition Penalty | 防止重复 | 不复读 | 1.1 |
结尾:感谢大佬 jingyaogong 开源的代码,让我能够学习这么多
感谢up主:算法魔法师 的讲解,非常清晰,强烈推荐
出错啦! - bilibili.comspace.bilibili.com/3546972082408260?spm_id_from=333.788.upinfo.head.click
最后,文章有一些硬核,写的也比较潦草(从笔记里改了改),如果有疑问可以留言。