ML训练流水线隐患深度拆解(0):Mask全谱系——你的模型在偷看答案

4 阅读14分钟

Mask全谱系——你的模型在偷看答案

系列引子:「ML训练流水线隐患深度拆解」是一个面向中高级深度学习工程师的技术系列,以DL为主,共7篇。

本系列的出发点来自本人团队中经常碰到的问题(vibe不vibe都会存在这些问题):训练流水线中存在大量"静默错误"——它们不触发异常、不打印警告,训练 loss 曲线看起来也在正常下降,但最终交付的模型或存在严重缺陷。

同时,审查起来很令人头大,哪怕借助cc、cursor等工具,模型开到max模式,在稍微复杂一点的项目中也经常会漏查。

arXiv 论文 TrainCheck(2506.14813)对 88 个静默错误案例的研究显示,根因中用户代码与框架层各占 32%;论文评测部分另外复现了 20 个真实错误案例,其中 18 个可在单次训练迭代内被检测到。

本篇定位:系列第1篇,也是后续一切讨论的地基。我们先把 Transformer 中所有 Mask 的语义梳理清楚,再聚焦两个最容易出错的方向:Padding Mask 的传递问题,以及 Causal Mask 的方向错误(⚠️ 高危,训练 loss 极低但推理质量极差)。

个人声明:本篇内容确是和LLM配合产出,本人负责内容骨架组织和皮肉审核。(一个令人生草的事实:有实习生给我发信息的时候,每一句都用qwen做优化xs


一、Transformer 中的七类 Mask

很多工程师在调试 attention 问题时,第一个困惑来自命名:为什么同一件事在不同地方叫不同的名字?为什么传进去的 mask 好像没有生效?

根源在于,Transformer 体系里存在七类语义完全不同的 Mask,它们在不同框架、不同论文、不同代码库里使用着不完全一致的名称和取值约定。把它们混用,就是播下静默错误的种子。

Mask 类型语义目的典型取值约定典型 Shape广播维度
Padding Mask忽略 padding token 对注意力的贡献PyTorch MHA: True=屏蔽;HF: 1=保留, 0=屏蔽(B, S)常扩展到 (B, 1, 1, S),屏蔽 key 维度
Attention Mask(HF 习惯)HuggingFace 中对可见性/padding 的统一入口1=保留,0=屏蔽(B, S) 或框架内部 4D框架内部转换后加到 attention score
Causal / Autoregressive Mask防止 query 看到未来的 keyMHA/常见手写实现:True=屏蔽;SDPA bool mask:True=保留;float mask:-inf=屏蔽(T, S)(1, 1, T, S)广播到 (B, H, T, S)
Loss Mask / Token Weight Mask只对特定 token 计算损失1=计入损失,0=忽略(B, S)逐元素乘以 per-token loss
MLM / Span / Whole-Word MaskBERT 式预训练的随机遮盖1=被遮盖需预测,0=原始 token(B, S)作用于输入替换/标签构造,不直接作用于 attention
Cross-Attention MaskEncoder-Decoder 中阻止 decoder 关注 encoder paddingHF 常见:1=保留;PyTorch 风格 bool:True=屏蔽(B, 1, T_q, T_k)(B, T_k)广播到 (B, H, T_q, T_k)
Sliding Window / Local Mask只允许关注局部窗口内 token实现相关,常见为 bool 或加性 mask(T, S)(1, 1, T, S)广播到 (B, H, T, S)

这七类 Mask 中,Padding Mask、Causal Mask、Loss Mask 是日常训练中最常出错的三类。本篇聚焦前两类;Loss Mask 将在后续篇章讨论。


二、2D → 4D 转换:一次 unsqueeze 写错,全部广播到错误维度

PyTorch 内部在处理 key_padding_mask 时,会将 (B, S) 的 2D mask 转换为 4D,再与 attention score (B, H, T, S) 相加2:

# PyTorch 内部实际做法(来自 torch/nn/functional.py)
key_padding_mask = (
    key_padding_mask.view(bsz, 1, 1, src_len)   # (B, 1, 1, S)
    .expand(-1, num_heads, -1, -1)               # (B, H, 1, S)
    .reshape(bsz * num_heads, 1, src_len)        # (B*H, 1, S)
)

关键点:src_len(key 序列长度)被放在最后一维。这样与 attention score (B*H, T, S) 相加时,mask 沿 T 维广播——每个 query 位置都看到相同的 key-side padding。

自己手写 mask 广播时,最常见的错误是搞反维度:

# ❌ 错误:把 S 放到了倒数第二维(query 的位置)
mask_4d = padding_mask.unsqueeze(1).unsqueeze(-1)  # (B, 1, S, 1)
# 效果:屏蔽的是 query 维度,而非 key 维度
# 后果:每个 query 位置被屏蔽,而 padding key 照常参与计算# ✅ 正确:S 放到最后一维(key 的位置)
mask_4d = padding_mask.unsqueeze(1).unsqueeze(2)   # (B, 1, 1, S)
# 或者更清晰的写法:
mask_4d = padding_mask[:, None, None, :]           # (B, 1, 1, S)

验证广播是否正确的最快方式:在单样本批次上打印 mask,检查 padding 位置是否在 key 维度(最后一维)被置为 -inf


三、框架命名不一致:True 和 1 的意思完全相反

这是一个工业界反复踩坑的陷阱,也是跨框架迁移代码时最容易引入静默错误的地方。

接口 / 框架参数名bool / 0-1 语义float mask 语义备注
PyTorch nn.MultiheadAttentionkey_padding_maskTrue=屏蔽float 时直接加到对应位置官方文档明确如此
PyTorch nn.MultiheadAttentionattn_maskTrue=不允许 attend-inf 或极小值表示屏蔽masked_fill(True, -inf) 习惯一致
PyTorch F.scaled_dot_product_attentionattn_maskTrue=允许 attendfloat mask 直接加到 score与 MHA 的 bool 极性相反
HuggingFace Transformersattention_mask1=保留,0=屏蔽内部会转成大负数 bias现代版本常用 torch.finfo(dtype).min
JAX / Flaxmask常见约定为 True=保留,False=屏蔽实现相关与 PyTorch MHA 易混淆
PaddlePaddleattn_mask实现相关常见为加性 mask使用时需看具体 API

PyTorch key_padding_mask(True=屏蔽)与 HuggingFace attention_mask(1=保留)的语义完全相反。 这是最容易造成混淆的组合。

HuggingFace 的处理链路在不同版本中实现细节略有变化,但核心语义稳定不变:1=保留,0=屏蔽,屏蔽位置最终会被转换成加到 attention score 上的极大负数 bias4-6:

# HF 内部转换(现代版本可近似理解为)
# 输入:attention_mask,1=保留,0=屏蔽
# 转换后:保留位置加 0,屏蔽位置加一个极大负数
extended_mask = (1.0 - attention_mask[:, None, None, :]).to(dtype)
extended_mask = extended_mask.masked_fill(
    extended_mask.to(torch.bool),
    torch.finfo(dtype).min
)
attention_scores = attention_scores + extended_mask
​
# 备注:较早版本中,很多模型代码会直接写成 * -10000.0;
# 原理相同,都是让 softmax 后对应概率接近 0。

四、Padding Mask 错误及检测

错误一:忘记传递 attention_mask

HuggingFace 的 tokenizer 会返回 attention_mask,但部分教程示例只取了 input_ids

# ❌ 错误:只传 input_ids,padding 位置会参与 attention 计算
encoding = tokenizer(texts, padding=True, return_tensors="pt")
output = model(encoding["input_ids"])
​
# ✅ 正确:传入完整 encoding,包括 attention_mask
output = model(**encoding)
# 等价于:
output = model(
    input_ids=encoding["input_ids"],
    attention_mask=encoding["attention_mask"]
)

HuggingFace 的 BERT 等模型在未提供 attention_mask 时,会自动创建一个全 1 的 mask——即所有 token 包括 padding 都参与计算4-6。在单序列推理时这不明显,在 batch 推理时 padding token 会影响非 padding 位置的 attention 权重,导致不同 batch size 下结果不一致。

错误二:传递了极性相反的 mask

跨框架迁移时,直接把 PyTorch 风格的 mask(True=屏蔽)传给 HuggingFace 接口:

# ❌ 错误:PyTorch 语义的 mask 传给 HF 模型
# padding 位置为 True,非 padding 位置为 False
pt_style_mask = (input_ids == tokenizer.pad_token_id)  # True=padding
output = model(input_ids, attention_mask=pt_style_mask)
# 后果:有效 token 被屏蔽,padding token 反而被保留# ✅ 正确:HF 的 attention_mask,1=有效,0=padding
hf_style_mask = (input_ids != tokenizer.pad_token_id).long()
output = model(input_ids, attention_mask=hf_style_mask)
# 或直接用 tokenizer 的输出:
output = model(**tokenizer(texts, padding=True, return_tensors="pt"))

检测方法

快速验证:构造一个含 padding 的 batch,将同一序列分别在独立批次和含 padding 的批次中推理,比较 logits:

import torch
​
def check_padding_mask(model, tokenizer, text):
    """验证 padding mask 是否正确生效"""
    # 单独推理
    single = tokenizer(text, return_tensors="pt")
    out_single = model(**single).logits
​
    # 与 padding 一起推理
    padded = tokenizer(
        [text, "x"],  # 第二个序列极短,会产生 padding
        padding=True,
        return_tensors="pt"
    )
    out_batched = model(**padded).logits[0, :out_single.shape[1]]
​
    max_diff = (out_single - out_batched).abs().max().item()
    print(f"最大差异:{max_diff:.6f}")
    # 如果 > 1e-4,说明 padding mask 未正确生效
    assert max_diff < 1e-4, "Padding mask 可能存在问题!"

五、Causal Mask ⚠️ 高危

Causal Mask 相关错误是本篇最值得重视的内容。方向写反时,训练 loss 可以非常低(因为模型可以直接看到下一个 token),但推理质量极差——这正是典型的静默错误特征。

上三角 vs 下三角:为什么容易混淆

Attention score 矩阵 (T, S) 中,行是 query,列是 key。Causal attention 要求:位置 i 的 query 只能关注 j ≤ i 的 key(即左边和自身),不能关注未来的 j > i

这意味着需要屏蔽上三角(不含对角线)的位置。但"upper triangular"这个描述容易引发混淆:

  • torch.triu(..., diagonal=1):返回上三角矩阵(对角线上方),这些是需要屏蔽的位置
  • torch.tril(...):返回下三角矩阵(含对角线),这些是允许关注的位置

混淆来源:有些实现用 tril 生成"允许关注的 mask",有些用 triu 生成"需要屏蔽的 mask",两者语义完全等价,但如果在使用 masked_fill 时没有对应好,就会反向屏蔽7。

三种错误实现对比

import torch
​
T = 4  # 序列长度# ❌ 错误版本1:屏蔽了下三角(允许关注未来,屏蔽历史)
# 训练时模型可以看到所有未来 token,loss 极低
# 推理时模型的 attention 分布与训练完全不同,输出质量极差
wrong_mask_1 = torch.tril(torch.ones(T, T)).bool()  # 下三角为 True
attn_scores.masked_fill(wrong_mask_1, float('-inf'))
# 等价错误:用 triu 但 diagonal=0(屏蔽了对角线和下方)# ❌ 错误版本2:triu 时 diagonal 参数错误
# diagonal=0 会把对角线(自注意力)也屏蔽掉,query 看不到自身
wrong_mask_2 = torch.triu(torch.ones(T, T), diagonal=0).bool()
attn_scores.masked_fill(wrong_mask_2, float('-inf'))
# 位置 0 的 query 连自己都看不到,attention 全为 -inf,softmax 后为 NaN# ❌ 错误版本3:mask 的 bool 极性理解错误
# masked_fill 在 mask=True 时填充,但错误地认为 True=保留
wrong_mask_3 = torch.tril(torch.ones(T, T), diagonal=0).bool()
# 正确逻辑应该是对上三角(True)填充 -inf
# 如果传入下三角(True=保留位置),就是把历史全部屏蔽
attn_scores.masked_fill(wrong_mask_3, float('-inf'))
​
# ✅ 正确版本:上三角(不含对角线)填充 -inf
causal_mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
# triu(diagonal=1) 返回:对角线上方为 True,其余为 False
# masked_fill(True) → 这些位置填 -inf → 正确屏蔽未来
correct_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
​
print(causal_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

视觉验证:正确的 causal mask 应该是右上角为 True(被屏蔽),左下角含对角线为 False(可以关注)。

影响分析:为什么训练 loss 低不代表模型对

Causal mask 写反时(屏蔽历史,允许未来),模型训练时可以直接看到目标 token,相当于答案泄露。loss 会快速下降到接近理论最小值,gradient 也显得正常。

推理时不存在未来信息,模型的 attention 分布与训练时完全不同,输出质量会退化到随机水平或更差。这类问题经常要训练数天后才会在下游评测中暴露,是最典型的高代价静默错误之一。

PyTorch seqlen_q ≠ seqlen_kv 的额外陷阱

在推理的 KV-cache 场景中,query 序列长度(T_q)往往为 1(只有新 token),而 key/value 序列长度(T_k)是完整的历史长度。此时 causal mask 的 shape 需要是 (T_q, T_k),而非 (T, T)

# ❌ 错误:始终用方形 mask,在 KV-cache 推理时 shape 不匹配或语义错误
T = seq_len
causal_mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
# 推理时 T_q=1,T_k=past_len,上面的 mask 形状不对# ✅ 正确:根据实际 T_q 和 T_k 动态生成
def make_causal_mask(T_q: int, T_k: int, device) -> torch.Tensor:
    """
    返回 causal mask,shape (T_q, T_k)
    mask[i, j] = True 表示 query i 不能关注 key j
    条件:j > (T_k - T_q + i),即未来的 key 位置
    """
    q_idx = torch.arange(T_q, device=device).unsqueeze(1)  # (T_q, 1)
    k_idx = torch.arange(T_k, device=device).unsqueeze(0)  # (1, T_k)
    # query 在原始序列中的绝对位置
    q_abs = T_k - T_q + q_idx
    return k_idx > q_abs  # True=屏蔽(未来的 key)

最佳实践:优先使用 is_causal=True

PyTorch 2.0+ 的 scaled_dot_product_attention 提供 is_causal 参数,由内核负责生成正确的 causal mask,并在 FlashAttention 等后端上融合计算:

import torch.nn.functional as F

# ✅ 推荐:让框架处理 causal mask
out = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True   # 框架自动处理上三角屏蔽,适配 seqlen_q != seqlen_kv
)

# 注意:is_causal=True 与显式 attn_mask 不能同时使用
# 如需同时处理 padding mask,需在外部合并后传入 attn_mask,
# 并将 is_causal 设为 False

### 额外注意:PyTorch MHA 与 SDPA 的 bool mask 极性并不一致

这是 PyTorch 体系内部一个非常容易踩坑的细节:

- `nn.MultiheadAttention` 中,bool mask 通常是 **True=屏蔽**
- `F.scaled_dot_product_attention` 中,bool `attn_mask` 则是 **True=参与 attention**

也就是说,**同样是 PyTorch,换了一个 API,bool mask 的语义可能正好相反**。跨实现迁移时,不能只看变量名叫不叫 `attn_mask`,必须重新核对文档与单元测试。

Causal Mask 检测方法

def verify_causal_mask(model_forward_fn, tokenizer, device="cpu"):
    """
    检测 causal mask 是否正确。
    原理:正确的 causal LM 中,位置 i 的 logits 不应受位置 i+1 及以后 token 影响。
    """
    text = "The quick brown fox"
    tokens = tokenizer(text, return_tensors="pt").to(device)
    input_ids = tokens["input_ids"]  # (1, T)

    # 基准:正常前向
    with torch.no_grad():
        logits_orig = model_forward_fn(input_ids).logits  # (1, T, V)

    # 篡改最后一个 token
    input_ids_modified = input_ids.clone()
    input_ids_modified[0, -1] = tokenizer.unk_token_id

    with torch.no_grad():
        logits_modified = model_forward_fn(input_ids_modified).logits

    # 如果 causal mask 正确,修改位置 T-1 的 token
    # 不应影响位置 0 到 T-2 的 logits
    diff = (logits_orig[0, :-1] - logits_modified[0, :-1]).abs().max().item()
    print(f"修改最后 token 对前序位置 logits 的最大影响:{diff:.6f}")
    
    if diff > 1e-4:
        print("⚠️ 警告:Causal mask 可能存在信息泄漏!")
    else:
        print("✅ Causal mask 验证通过")

六、快速参考:Mask 使用决策树

需要屏蔽 padding token?
  ├── 使用 PyTorch 原生 API → key_padding_mask,True=屏蔽
  └── 使用 HuggingFace → attention_mask,1=保留,0=屏蔽
      └── 直接传 tokenizer(**encoding) 最安全

需要因果遮盖(自回归生成)?
  ├── PyTorch 2.0+(推荐)→ scaled_dot_product_attention(is_causal=True)
  ├── 手动构造 → triu(ones(T,T), diagonal=1).bool(),True 位置填 -inf
  └── 推理 KV-cache(T_q ≠ T_k)→ 必须动态生成 (T_q, T_k) 形状的 mask

跨框架迁移时 → 检查 mask 极性,True/1 的语义可能完全相反

总结

本篇梳理了 Transformer 体系中七类 Mask 的语义边界,重点拆解了三个高频陷阱:

  1. 2D→4D 广播维度unsqueeze 位置决定屏蔽的是 key 还是 query,写错了全部广播到错误方向
  2. 框架极性冲突:PyTorch True=屏蔽 vs HuggingFace 1=保留,跨框架迁移必须显式核验
  3. Causal Mask 方向写反(最高危):训练 loss 低不等于模型对,推理阶段才会暴露,损失极大

Mask 类错误的共同特征是症状延迟——错误在训练时埋入,但往往在评测或上线后才暴露。有效防御手段是编写如上所示的单元测试,在每次模型结构变更后运行验证。


参考资料

[1]  Yuxuan Jiang et al., "Training with Confidence: Catching Silent Errors in Deep Learning Training with Automated Proactive Checks," arXiv:2506.14813, 2025. arxiv.org/abs/2506.14…

[2]  PyTorch 源码,torch/nn/functional.pymulti_head_attention_forward 函数中 key_padding_maskview/expand/reshape 处理。github.com/pytorch/pyt…

[3]  PyTorch 官方文档,torch.nn.MultiheadAttentionkey_padding_mask 参数说明:"For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention." docs.pytorch.org/docs/stable…

[4]  HuggingFace Transformers 讨论,"Clarification on the attention_mask",确认 HF attention_mask 中 1=保留、0=屏蔽,框架内部反转后加到 attention score。discuss.huggingface.co/t/clarifica…

[5]  HuggingFace Transformers 源代码,最新实现中内部转换已经使用torch.finfo(dtype).min而不是-10000github.com/huggingface…

[6]  HuggingFace Transformers 讨论,"Do automatically generated attention masks ignore padding?",确认未传递 attention_mask 时模型自动生成全 1 mask,不处理 padding。discuss.huggingface.co/t/do-automa…

[7]  PyTorch 讨论,"Should Transformer's causal attention mask be upper-triangular or lower-triangular?",明确 PyTorch generate_square_subsequent_mask 在上三角放置 -infdiscuss.pytorch.org/t/should-tr…


下篇预告

第2篇:FP16 数值稳定性——你的梯度在悄悄消失

Mask 错误会让模型"偷看答案";而 FP16 的数值问题则更隐蔽——梯度在不知不觉中下溢为零,或 loss 在某次迭代后突然跳变到 NaN。下篇将拆解 FP16/BF16 的表示范围差异、loss_scale 的工作原理与失效条件、混合精度训练中常见的精度劣化模式,以及如何用 gradient norm 监控提前发现问题。