阅读minimind源码学习大模型知识(大模型必备基础知识)

0 阅读23分钟

声明,本文主要参考:

minimind官方代码库:
github.com/jingyaogong…

以及b站up主:算法魔法师
www.bilibili.com/video/BV1N9…

以及豆包、千问两款APP进行知识问答补全。

BBPE

强烈建议先观看算法魔法师的视频,了解编码知识(BPE和BBPE分词)
揭秘分词底层技术“BPE”-大部分大模型都在用的分词算法!_哔哩哔哩_bilibili

minimind使用的是BBPE
一句话总结:按照utf-8进行分词,按照字节进行分层,解决中英文的问题
这也解释了,为什么词表为什么会有一堆乱码,起初还以为是我pycharm设置的问题~

minimind使用的是
PreTrainedTokenizerFast

其中比较有意思的
tokenizer.apply_chat_template,解答了,如何将用户常见的 messages输出变成大模型能够理解的格式:

你只需要传入

messages tokenizer.apply_chat_template(messages, tokenize=False) # 内部自动调用 chat_template 拼好格式!
运行结果,因此可以看到,这个其实类似于一个桥接器,将用户的输入转化成一个大模型可分词的

为什么tokenizer.json里面的此表是乱码?因为是使用了BBPE分词,这个不是显示的BUG~

只要直到,tokenizer会将输入进行分词并且将分词后的结果,输出为 id序列即可


Attention

Attention涉及很多细节,我在下文进行补充,可以看左侧目录选择不懂的知识点进行学习

github.com/jingyaogong…

以上红框属于attention的部分
从源码来看,实际上经历了以下的步骤

[输入]   
 ↓ 
Q投影 → K投影 → V投影:xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
↓ 
分头塑形    xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
↓ 
Q/K 归一化    xq, xk = self.q_norm(xq), self.k_norm(xk)
↓ 
RoPE注入位置信息    xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
↓ 
KV缓存拼接    xk = torch.cat([past_key_value[0], xk], dim=1)
↓ 
KV重复匹配Q头    xk = torch.cat([past_key_value[0], xk], dim=1)
↓ 
注意力打分+掩码    output = self.attn_dropout(F.softmax(scores.float(), dim=-1).type_as(xq)) @ xv
↓ 
多头拼接融合   output.transpose(1, 2).reshape(bsz, seq_len, -1)  
↓ 
输出线性变换   self.resid_dropout(self.o_proj(output))
↓ 
[最终输出+缓存] return output, past_kv

为什么v 不进行归一化?
这里就要理解 attention的原理

核心原因(一句话)
注意力公式里,只有 Q 和 K 会做「点积打分」,V 从来不参与打分! 所以 只有 Q、K 需要归一化稳定数值,V 完全不需要

看注意力公式就懂了
plaintext
注意力分数 = Softmax( Q · K^T / √d ) · V
Q 和 K → 做点积 → 数值大小会直接影响分数→ 必须归一化,不然分数会爆炸 / 消失
V → 只是被加权求和→ 数值大小不影响分数分布→ 不需要归一化

大白话解释
Q:查询
K:键
V:值
打分过程:Q 去匹配 K → 打出分数 → 用分数去加权 V
Q、K 是 “评委”  → 评委打分必须公平、尺度统一 → 必须归一化
V 是 “选手”  → 选手本身不需要归一化 → 不用管

attention公式

详细原理可以看UP主“算法魔法师”视频:

www.bilibili.com/video/BV1jT…

关于attention,一定会有很多的疑问:
1、为什么要对kv进行复制?(GQA)
2、什么是旋转位置编码(RoPE,苏剑林大神提出)
3、什么是past_key_value(kv cache)
4、怎么做mask掩码
这些都是之前在原版attention不一样的地方,详细原理见下文


RMSNorm

一种归一化的手段。

参考豆包:
RMSNorm = 均方根归一化,极简版:只做缩放不学偏移,比 LayerNorm 更快、更简单。
和 LayerNorm 最大区别

  • LayerNorm:归一化 + 缩放 weight) + 平移 (β)
  • RMSNorm:只归一化 + 缩放 (weight),删掉平移 β

因为没有做平移,只做了归一化 + 缩放,实际上保持了分布的均值

计算更快去掉 LayerNorm 的均值平移,少算一堆运算,训练推理提速。

  1. 效果不输层归一化大模型场景下,删掉偏置 β 不影响精度,LLaMA 全系都弃 LN 用它。
  2. 训练更稳、梯度更顺滑只做尺度缩放,不偏移数据分布,缓解梯度消失
  3. 适配注意力更友好你代码里只用在Q、K上:统一 QK 向量幅值,注意力打分更均衡,不容易出现极端权重。

和 LayerNorm 对比

  • LayerNorm:减均值 + 除方差 + 缩放 + 偏移

  • RMSNorm:只除均方根 + 缩放少两步,更快、更轻、大模型最优解。

    总之:减少参数规模、并且减少复杂的非线性变换
    pytorch官方自带了:torch.nn.RMSNorm(dim, eps=eps)
    PyTorch ≥ 2.1 正式内置 nn.RMSNorm

    大佬的源码里是自己手搓:

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return (self.weight * self.norm(x.float())).type_as(x)

可以看出,缩放是可学习的参数 self.weight = nn.Parameter


GQA

参考:
www.bilibili.com/video/BV1xV…

简单理解: MHA是原始的多头
MQA:解决MHA的KV cache来回搬运问题,多个Q对应一个KV
GQA:MQA就一个kv,表达力不足,搞成分组。

代码里的原理是:

 k、v映射可学习的参数本身就少,例如8个Q,分两组,这里的k v 的只有4

self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)


xq, xk, xv = (xq.transpose(1, 2), repeat_kv(xk, self.n_rep).transpose(1, 2), repeat_kv(xv, self.n_rep).transpose(1, 2))

这里的repeat就是重复而已

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    bs, slen, num_key_value_heads, head_dim = x.shape
    if n_rep == 1: return x
    return x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim)

本质上,实际上就是削减了k、v的可学习的参数量。

为什么这里的映射,不加bias? ( bias=False)
nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
结合你提供的代码和 RMSNorm 的原理,Q、K、V 投影层不加 bias 的原因,主要可以归结为以下三点:

  1. 代码中的 q_norm 和 k_norm 是 Qwen3 等现代模型的标配 你提供的代码中,在 Q 和 K 投影后显式加入了 RMSNorm,这其实是 Qwen3 架构相对于 Qwen2.5 的一个重要升级(称为 QK 预归一化)。
  • Qwen3 的做法:不仅 Q 和 K 后面加了 RMSNorm,而且 QKV 投影层的 bias 也是通过配置项 config.attention_bias 来灵活控制的(默认通常设为 False)。
  • Qwen2.5 的做法:Qwen2.5 的 QKV 投影层是直接把 bias=True 写死在代码里的,并且没有 QK 预归一化层。 你的这段代码明显采用了更先进的 Qwen3 架构思路。在这种架构下,因为已经有了专门的归一化层来稳定 Q 和 K 的数值分布,投影层的偏置就显得更加多余了。
  1. 工程上的极致优化:能省则省 在大语言模型(LLM)中,参数量动辄几十亿甚至上千亿。虽然单个偏置向量(大小等于隐藏层维度)看起来不大,但每一层、每一个注意力头累积起来,去掉这些冗余的 bias 也能:
  • 减少显存占用:模型权重文件更小,加载和部署更省资源。
  • 提升推理速度:矩阵乘法(GEMM)后少了一次向量加法运算。在大规模并发推理时,这种微小的优化能带来可观的吞吐量提升。

flash_attn

具体的原理参考豆老师的讲解,实际上使用只需要调用官方API

把 Q/K/V 切成小块(tile),只在高速片上内存(SRAM)里算小块注意力,不存完整 N×N 矩阵,用在线 softmax 渐进合并结果。 NVIDIA
两大核心技术
(1)分块计算(Tiling)NVIDIA
把 Q, K, V 沿序列维度切成小方块(刚好放进 SRAM)
外层循环:加载 K, V 块到 SRAM
内层循环:加载 Q 块,在 SRAM 内算
全程不写 N×N 矩阵到 HBM
(2)在线安全 softmax(Online Softmax)
softmax 依赖整行最大值与指数和,Flash Attention 用增量方式
每块记录:行最大值 m、指数和 l
块间合并:重新缩放、更新最大值与和
结果完全精确、无精度损失
(3)反向:重计算(Recomputation)
不存巨大的 S/P 矩阵
反向时重新计算小块 S/P → 用 “多算一点” 换 “省巨量显存”

三、效果(一句话总结)
显存:从 O (N²) → O(N) (序列长 32K 从 4GB → 几十 MB)
速度:2–4 倍加速NVIDIA
精度完全精确(不是近似)
支持更长上下文:GPT-4 128K、LLaMA 3 128K 都靠它

在代码层面

if self.flash and (seq_len > 1) and (not self.is_causal or past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
    output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=self.is_causal) # 1,8, 22, 96

这两段代码的效果是等价的

scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) # 初始:182222,后面就是1,8,1,23
if self.is_causal: scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1)
if attention_mask is not None: scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9
output = self.attn_dropout(F.softmax(scores.float(), dim=-1).type_as(xq)) @ xv # 1,8,22,96

mask掩码

mask矩阵的本质是遮掩信息,不让模型看到未来的信息和无效的padding信息。

attention公式

scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) 实现的是

scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)

scores 是什么?
形状:(batch, n_heads, seq_len, seq_len)这是Q × K 得到的注意力分数表

if self.is_causal: scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1)
if attention_mask is not None: scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9

其中 因果编码是
torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1)
生成了一个类似的矩阵

[ [0, -inf, -inf],
  [0, 0, -inf],
  [0, 0, 0] ]

scores[:, :, :, -seq_len:] += 这个矩阵,将未来的信息,给置为负无穷。
后面softmax,将负无穷压成0。
其中mask矩阵,则是有效的为1,无效的为0;
scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9
可以看到,也是同样的原理,将score 遮掩的部分,置为了负无穷。

后面softmax,将负无穷压成0。


KV cache

参考,非常建议观看,讲解的很清楚
www.bilibili.com/video/BV1dU…

本轮 吃的 QKV + 历史的 KV向量,足以计算本轮的输出
在模型进行梯度传播 之前,历史的KV结果都是固定的。

阅读源码时,这里要区分两种场景:1、初始输入场景 2、后续每个字输入场景

1、初始输入场景,初始kv cache,也是首token输出的场景(补充一句,难怪很多网站说 首token输出是 kv cache初始化的过程 ):

xq, xk, xv 的长度都是 bs,sel_len,head,head_dim
。。。。。。
past_kv = (xk, xv) if use_cache else None  # 1,seq,headcnt,headdim
。。。。。。
return output, past_kv

外层循环,总共有8层,那么每层的kv cache都存起来了,存到presents

presents = []
for layer, past_key_value in zip(self.layers, past_key_values):
    hidden_states, present = layer(
        hidden_states,
        position_embeddings,
        past_key_value=past_key_value,
        use_cache=use_cache,
        attention_mask=attention_mask
    )
    presents.append(present) # 这里存放kv cache,这样列表永远是8、8代表的attention的层数,里面的k、v是 1,x,4,96,x会不断增长
。。。。    
return hidden_states, presents, aux_loss

补充:input_ids = torch.cat([input_ids, next_token], dim=-1) 最后将本轮输出的首token,变成输入

2、首字符

最外层的generate 函数里,

这里的past_len 查看attention第一层的k的长度,也是历史长度
past_len = past_key_values[0][0].shape[1] if past_key_values else 0 

# 注意这里的 input_ids[:, past_len:],实际上就是将首token取出了
outputs = self.forward(input_ids[:, past_len:], attention_mask, past_key_values, use_cache=use_cache, **kwargs)
最底层的attention模块,将 当前的kv和历史拼接,长度不断变长(bs,oldseq+1,headcnt,headdim)
if past_key_value is not None:
    xk = torch.cat([past_key_value[0], xk], dim=1)
    xv = torch.cat([past_key_value[1], xv], dim=1)
    
    
外层的presents,记录了 新的 (bs,oldseq+1,headcnt,headdim)
presents = []
for layer, past_key_value in zip(self.layers, past_key_values):
    hidden_states, present = layer(
        hidden_states,
        position_embeddings,
        past_key_value=past_key_value,
        use_cache=use_cache,
        attention_mask=attention_mask
    )
    presents.append(present) # 这里存放kv cache,这样列表永远是8、8代表的attention的层数,里面的k、v是 1,x,4,96,x会不断增长

3、后续字符就是同样的道理了


RoPE

参考:www.bilibili.com/video/BV1Fj…
强烈建议看一下up主的视频。

原本的正余弦位置编码是加法,使用旋转矩阵,将第m个词旋转m个角度,第n个词旋转n个角度。
旋转矩阵长这个样子

第m个词和第n个词之间的相对位置是 n -m 个角度

RoPE = 把 768 维向量两两分组 → 384 个 2D 向量
每个 2D 向量按位置旋转不同角度
旋转完拼回 768 维
位置信息就自然融进语义里了!

假设是 bs、seq-len,768维 的输入,
768维度,两两一组,总共384组
不同位置 → 角度不同(核心)
位置 0 → 所有 384 组都旋转 0° 简化版,实际上,每一组的旋转度数也都不一样
位置 1 → 所有 384 组都旋转 5° 简化版,实际上,每一组的旋转度数也都不一样
位置 2 → 所有 384 组都旋转 10° 简化版,实际上,每一组的旋转度数也都不一样
位置 3 → 所有 384 组都旋转 15° 简化版,实际上,每一组的旋转度数也都不一样
这就是:
固定一个位置,旋转全部 的 384 对,每一对的旋转角度略微不同,详情看下面的源码

每组应该旋转多少角度的公式,注意这里的 i 代表组

数学公式(就是 RoPE 标准公式)

��=1����2����

d:单头维度(你这里是 96,不是全局 768!重点)
i:维度组下标(0,1,2… 两两为一组)
10000 是基础底数

阅读一下源码

def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: dict = None):
    freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
    if rope_scaling is not None: # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
        orig_max, factor, beta_fast, beta_slow, attn_factor = (
            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
            rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
        )
        if end / orig_max > 1.0:
            inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
            low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
            ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
            freqs = freqs * (1 - ramp + ramp / factor)
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
    return freqs_cos, freqs_sin

首先这个就实现了类似的公式
freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0

只不过这里源码的rope_base不是10000,默认是100w
注意:这里的 freqs 是长度为48的向量,代表每一组的转速不一样。并非上面的例子里,每一个位置的,所有组共用一个旋转角度。

t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()

实现了计算好,每个位置,旋转多少度。

形状变化: t : [32768] freqs: [48] 外积后 → [32768, 48]

这里的下标,实际上是 序列位置 + 组的位置

接下来,把每一组的旋转角度计算好。

freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) 
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)

变成了 [32768, 96],形成最终查询的大表(以空间换时间)。

注意到没,这里的是拼接,也就是变成了 [c0, c1, c2, ..., c47, c0, c1, c2, ..., c47]的顺序
而不是[c0,c0, c1,c1, c2,c2, ..., c47,c47]

我们希望的是[c0,c0, c1,c1, c2,c2, ..., c47,c47],因为前面的维度每两个为一组,旋转的角度Ci 应该是一样的。

那么这里是有bug嘛?不是的,这里又涉及到一个关键的代码

这里实际上是

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    def rotate_half(x): return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
    q_embed = ((q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))).to(q.dtype)
    k_embed = ((k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))).to(k.dtype)
    return q_embed, k_embed

实际上,我们上面理解的,两两相邻作为一组,也是错的
实际上,是第0个位置 + 第48个位置,凑成一组,间隔48。并非两两相邻一组。

引用豆老师的说法,
单头维度 d=96
预计算出来基础频率角:θ0,θ1,...,θ47(共 48 个)
执行 torch.cat([cosθ, cosθ], dim=-1)得到总长 96 的数组:
�=[cos⁡θ0,cos⁡θ1,...,cos⁡θ47,cos⁡θ0,cos⁡θ1,...,cos⁡θ47$$$$�=[sin⁡θ0,sin⁡θ1,...,sin⁡θ47,sin⁡θ0,sin⁡θ1,...,sin⁡θ47

def rotate_half(x):return torch.cat((-x[..., 48:], x[..., :48]), dim=-1)

设向量 =[0,1,2,3,...,47,0,1,2,3,...,47]�=[�0,�1,�2,�3,...,�47,�0,�1,�2,�3,...,�47]
前 48 维: �0∼�47
后 48 维: �0∼�47
执行后: ������_ℎ���(�)=[−�0,−�1,...,−�47, �0,�1,...,�47

旋转公式
����=�⋅�+������_ℎ���(�)⋅�
拆开成对看,只看第一组一对维度:原向量一对:(a0, b0(a_0,\ b_0对应位置余弦正弦:\cosθ_0,\ \sinθ_
代入计算:

�0′=�0⋅cos⁡θ0−�0⋅sin⁡θ0 �0′=�0⋅cos⁡θ0+�0⋅sin⁡θ0
完美就是标准二维旋转公式!

为什么拼接 [cos,cos] 刚好能用
拼接后前 48 个cos对应每组第一个元素
拼接后后 48 个cos对应每组第二个元素
rotate_half 天然把每一对维度拆分调换正负
相乘相加后,自动完成 一组维度共用同一个 θ


YaRN

当理解了rope算法之后,上面有一小段代码,包含了yarn算法的内容。

if rope_scaling is not None: # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
    orig_max, factor, beta_fast, beta_slow, attn_factor = (
        rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
        rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
    )
    if end / orig_max > 1.0:
        inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
        low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
        ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
        freqs = freqs * (1 - ramp + ramp / factor)

YaRN 的天才想法:
高速维度(近处)几乎不缩,慢速维度(远处)使劲缩
中间维度 → 平滑过渡缩
远近都保住!

说白了,就是插值
引用豆老师的讲解:

固定参数
orig_max = 2048:模型原生支持长度
factor = 16:扩展 16 倍
beta_fast = 32:高频阈值
beta_slow = 1:低频阈值

关键代码:
inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
作用:
根据阈值 b,计算 “从第几号频率组开始缩放”
我们代入算一遍:
① 计算 inv_dim (beta_fast=32)

b = 32
分子 = dim * log(orig_max/(b*2π))
分母 = 2 * log(1e6)

结果 ≈ 17.89 → 向下取整 → **17**

② 计算 inv_dim (beta_slow=1.0)

b = 1.0
结果 ≈ 24.57 → 向上取整 → **25**

low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)

结果:

  • low = 17
  • high = 25

含义:

  • 0 ~ 17 组:不缩放(高频,近处)
  • 17 ~ 25 组:平滑过渡
  • 25 ~ 47 组:完全缩放(低频,远处)

生成平滑过渡 ramp

ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001),0, 1)

torch.arange(48) → [0,1,2,...,47]
计算:

ramp[i] = (i - 17) / (25 - 17) = (i-17)/8

最终 ramp 结果:

  • i=0~17 → 0
  • i=17~25 → 0 → 1 平滑上升
  • i=25~47 → 1

YaRN 核心公式:缩放频率

freqs = freqs * (1 - ramp + ramp / factor)

代入 factor=16:

scale = (1 - ramp) + ramp / 16

最终:

  • ramp=0 → scale=1 → 不缩放

  • ramp=1 → scale=1/16 → 缩 16 倍

  • 中间 → 平滑缩放 慢慢开始缩放

最终效果(最关键)
频率组 0~17:不缩放(高速,近处)
频率组 17~25:慢慢开始缩放
频率组 25~47:直接缩小 16 倍(低速,远处)


FeedForward

前馈层,和普通的线性层还不一样

class FeedForward(nn.Module):
    def __init__(self, config: MiniMindConfig, intermediate_size: int = None):
        super().__init__()
        intermediate_size = intermediate_size or config.intermediate_size
        self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

翻译成人话:

  1. x 输入:[batch, seq_len, 768]
  2. 经过 gate_proj → 高维向量
  3. 经过激活函数 → 变成门控信号
  4. 门控信号 × up_proj 升维向量→ 相当于 “筛选有用信息”
  5. 最后 down_proj 降维回 768

这就是 SwiGLU / GateFFN,现在大模型标配。

MOEFeedForward

这个会比较麻烦一些。

class MOEFeedForward(nn.Module):
    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) # # 路由门
        self.experts = nn.ModuleList([FeedForward(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)])
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.shape
        x_flat = x.view(-1, hidden_dim) #  # [B*S, 768] # 2. 展平所有Token,方便统一路由
        scores = F.softmax(self.gate(x_flat), dim=-1) # [B*S, num_experts]
        topk_weight, topk_idx = torch.topk(scores, k=self.config.num_experts_per_tok, dim=-1, sorted=False) # [B*S, num_experts]
        if self.config.norm_topk_prob: topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20) # 权重归一化
        y = torch.zeros_like(x_flat) # [B*S, 768]

        for i, expert in enumerate(self.experts):
            mask = (topk_idx == i) #  # [B*S, num_experts]
            if mask.any():
                token_idx = mask.any(dim=-1).nonzero().flatten() # [N个字符选中了,这N个字符在B * S里的index]
                weight = topk_weight[mask].view(-1, 1) # 将对应的权重取出,权重的维度是 [N个字符选中了,1]
                y.index_add_(0, token_idx, (expert(x_flat[token_idx]) * weight).to(y.dtype))
            elif self.training:
                y[0, 0] += 0 * sum(p.sum() for p in expert.parameters()) # 为了让 PyTorch 计算图知道:这个专家被使用了!如果一个专家完全没被用到,它的参数就不会出现在计算图里,多轮之后,优化器会报错:

        if self.training and self.config.router_aux_loss_coef > 0:
            load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
            self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
        else:
            self.aux_loss = scores.new_zeros(1).squeeze()
        return y.view(batch_size, seq_len, hidden_dim)

其中比较的地方有:门,选择哪个专家
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

专家们:
self.experts = nn.ModuleList([FeedForward(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)])

1、forward时,
会将输入全部展平:
x_flat = x.view(-1, hidden_dim)  # # [BS, 768] # 2. 展平所有Token,方便统一路由*

2、然后计算门分数
scores = F.softmax(self.gate(x_flat), dim=-1)  # [BS, num_experts]*

3、然后把需要的TOP专家取出
topk_weight, topk_idx = torch.topk(scores, k=self.config.num_experts_per_tok, dim=-1, sorted=False)  # [B*S, num_experts]

4、初始化最终结果
y = torch.zeros_like(x_flat)  # [BS, 768]*

5、for循环遍历所有的专家:for i, expert in enumerate(self.experts):

6、找到用到了第i个专家的token
mask = (topk_idx == i)  # # [BS, num_experts]*
token_idx = mask.any(dim=-1).nonzero().flatten()  # [N个字符选中了,这N个字符在B * S里的index]
注意,这里的向量长度,不是B*S全部了,而是N,N 是选中第i个专家的token

7、并且将对应的权重也捞出来
weight = topk_weight[mask].view(-1, 1)  # 将对应的权重取出,权重的维度是 [N个字符选中了,1]

8、使用对应的专家计算的结果,放到目标里。
y.index_add_(0, token_idx, (expert(x_flat[token_idx]) * weight).to(y.dtype))
这段代码是精髓
expert(x_flat[token_idx]) * weight

解释
[B*S, 768] 经过tokenidx 变成了 【N,768】,经过expert专家映射后变成了 N * 768,然后 与weight权重【N,1】点乘,还是N * 768。然后根据index_add_的 token_idx索引,将会这N行向量,加入到对应的y的位置上

截至到上面的步骤,讲完了多专家时怎么预测,但是为了在训练的时候,能够让所有的专家尽可能都被训练到,因此增加了一个

辅助损失函数self.aux_loss

if self.training and self.config.router_aux_loss_coef > 0:
    load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
    self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef

这里很有趣
topk_idx 的大小是 【b*s, num_experts】,例如

topk_idx [6,2] 是这样:

[
    [1, 3],   # token 0
    [0, 1],   # token 1
    [2, 3],   # token 2
    [1, 0],   # token 3
    [3, 2],   # token 4
    [0, 1],   # token 5
]

F.one_hot(topk_idx, self.config.num_experts)
得到的结果 shape =  [6, 2, 4]
我直接把它画出来给你看:

# 维度含义:[第几个token] [第几个选中专家] [4个专家的one-hot]

# -----------------
# token 0: 选了 [1,3]
# -----------------
[
    [0, 1, 0, 0],  # 选专家1 → one-hot
    [0, 0, 0, 1]   # 选专家3 → one-hot
],

.mean(0) 变成了  [ 2, 4]
内容大概是:

[
    [0.33, 0.50, 0.16, 0.16],  # 第1个选择位的平均
    [0.16, 0.33, 0.16, 0.50]   # 第2个选择位的平均
]

scores.mean(0) 的shape 是 【4】
load是【2,4】

这里可能有一个bug
load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
会变成一个2*4的向量,总和为2。
我不确定这里会不会是个bug?应该再除以 self.config.num_experts。这样全局之和为1?

self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
这里是损失函数,是希望load的真实分布和 门输出的score的分布尽可能不同。

如果两个完全一致,则 aux_loss 越大。
两个完全不一致,则 aux_loss 越小。
这个辅助损失函数是为了鼓励越小越好。即门的输出和真实使用的分布,越不同越好。

乘以 self.config.num_experts 是为了均一化损失函数。

比如

(1/8 * 1/8) * 8 = 1/8 = 0.125

专家越多 → 这个值越小
专家 = 16 → 0.0625
专家 = 32 → 0.03125
结果:
aux_loss 会变得超级小,几乎等于 0!
乘以个 self.config.num_experts ,就会让损失函数变成1


温度、重复性惩罚、top_k、top_p、随机化采样

这些也是经常容易混淆的地方,也是我们模型设置的常用的参数

for _ in range(max_new_tokens):
    past_len = past_key_values[0][0].shape[1] if past_key_values else 0  # past_keY_values 是层数的列表,比如8.然后每个元素是tuple,(k,v),其中k的是1,seq,4, 96  ,其中每次循环seq会+1,这样下下面的输入永远是最新的
    outputs = self.forward(input_ids[:, past_len:], attention_mask, past_key_values, use_cache=use_cache, **kwargs)
    attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[0], 1)], -1) if attention_mask is not None else None
    logits = outputs.logits[:, -1, :] / temperature
    if repetition_penalty != 1.0:
        for i in range(input_ids.shape[0]):
            seen = torch.unique(input_ids[i])
            score = logits[i, seen]
            logits[i, seen] = torch.where(score > 0, score / repetition_penalty, score * repetition_penalty)
    if top_k > 0:
        logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float('inf')
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        mask = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
        mask[..., 1:], mask[..., 0] = mask[..., :-1].clone(), 0
        logits[mask.scatter(1, sorted_indices, mask)] = -float('inf')

温度

其中温度 temperature 可以看到,是
logits = outputs.logits[:, -1, :] / temperature
temperature越大,比如2,则logits都会压缩变小,原本的各个词之间的差距,变小,输出就更随意
temperature越大,比如0.5,则logits就会变的很大,差距变大,输出就会变的固定(选择概率最大的)

重复性惩罚

for i in range(input_ids.shape[0]):
    seen = torch.unique(input_ids[i])
    score = logits[i, seen]
    logits[i, seen] = torch.where(score > 0, score / repetition_penalty, score * repetition_penalty)

seen = torch.unique(input_ids[i]) 获得已经看过的token id。
注意,这里的seen 的内容,是 id,最大时词表的大小。代表词表里的对应词的id
score = logits[i, seen] 将这些已经看过的词的预估值拿出来

repetition_penalty越大,则越惩罚重复,即鼓励不重复
举个例子:repetition_penalty 假设 为 2
logits[i, seen] = torch.where(score > 0, score / repetition_penalty, score * repetition_penalty)

则正的数值,削弱2倍,变小,负数 * 2,变得更小。 变小就意味着不会被选中了

top_k

硬性只要最大的K个:
logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float('inf')

torch.topk(logits, top_k)[0] 返回topk的数值,数值是 由大到小
[..., -1, None]] 返回每一行的最后一个值

最后,将小于这个值的,都赋值为负无穷,不要被选中

top_p

if top_p < 1.0:# 1. 把 logits 从大到小排序
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)# 2. 转成概率,从大到小累加,超过 p 的扔掉
    mask = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
    
    # 3. 关键小修正:保证至少保留一个词
    mask[..., 1:], mask[..., 0] = mask[..., :-1].clone(), 0
    # 4. 把掩码还原回原来顺序,禁用低概率词
    logits[mask.scatter(1, sorted_indices, mask)] = -float('inf')

简单讲,先排序,再累计求和,右移一位(边界处理,右移才能确保一定大于等于),然后还原到对应的遮掩位置,将其置为 负无穷

例如第一个元素是0.6,第二个是0.24。阈值0.8
则只有第一个位置是满足的,其实是第一个和第二个元素满足,要右移一位。

scatter一般用于排序后的还原,将mask的元素,通过 indices 还原到对应的位置上

随机化采样

next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1) if do_sample else torch.argmax(logits, dim=-1, keepdim=True)

如果开启随机化采样
按照softmax结果的概率值进行采样,返回对应元素的下标
torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

否则
只取最大的
torch.argmax(logits, dim=-1, keepdim=True)

最后给你一个超级直观对比
表格

方法含义优点常用值
Top-K保留前 K 个概率最高词简单20~50
Top-P保留概率总和≥P 的词智能、灵活0.9
Temperature控制随机性创造力0.7
Repetition Penalty防止重复不复读1.1

结尾:感谢大佬 jingyaogong 开源的代码,让我能够学习这么多

github.com/jingyaogong…

感谢up主:算法魔法师 的讲解,非常清晰,强烈推荐

出错啦! - bilibili.comspace.bilibili.com/3546972082408260?spm_id_from=333.788.upinfo.head.click

最后,文章有一些硬核,写的也比较潦草(从笔记里改了改),如果有疑问可以留言。