NVIDIA Nemotron Nano 2:精准高效的混合Mamba-Transformer推理模型详解

326 阅读26分钟

论文标题:NVIDIA Nemotron Nano 2: An Accurate and Efficient Hybrid Mamba-Transformer Reasoning Model
发布机构:NVIDIA
论文地址arxiv.org/abs/2508.14…

一、核心亮点/重要结论

  1. 架构创新:采用混合Mamba-Transformer架构,替换Transformer中大部分自注意力层为Mamba-2层,在长推理链生成场景下(如8k输入+16k输出)实现推理速度跃升,同时保持精度。

  2. 性能突破:最终的9B参数模型(Nemotron-Nano-9B-v2)与同规模模型(如Qwen3-8B)相比:

    • 精度:在数学(AIME24/25)、代码(LiveCodeBench)、长上下文(RULER-128k)等推理基准上相当或更优
    • 吞吐量:生成密集场景(1k输入/8k输出、8k输入/16k输出)下实现3×~6×更高推理吞吐量(单NVIDIA A10G GPU,bfloat16精度)。
  3. 硬件友好性:通过剪枝与蒸馏,实现单A10G GPU(22GiB内存)上支持128k tokens长上下文推理(bfloat16精度),突破小显存硬件的长序列处理限制。

  4. 数据与训练优化

    • 预训练:基于20万亿tokens,结合高质量 curated数据(网页、数学、代码)与synthetic数据(STEM、多语言QA、推理题),提升多领域推理能力;
    • 对齐:通过多阶段SFT、GRPO、DPO、RLHF,平衡工具调用、长上下文与对话能力;
    • 预算控制:支持指定“思考token数”,模型可在限制思考步骤内输出正确结果,且避免格式错乱。
  5. 开源贡献:在Hugging Face开源3个模型 checkpoint(9B对齐版、9B基础版、12B基础版)及大部分预训练/后训练数据集,降低社区研究门槛。

二、研究背景与问题提出

2.1 大模型推理任务的核心挑战

当前大语言模型(LLM)在复杂推理场景(如数学证明、代码生成、长文档分析)中面临两大核心矛盾:

  1. 精度与效率的平衡

    • Transformer架构的自注意力机制(计算复杂度O(n²))在长序列(如16k+ tokens)推理时,计算量与KV缓存内存占用急剧增加,导致吞吐量骤降;
    • 纯Mamba模型虽在长序列上效率更高(O(n)复杂度),但在需要精细语义关联的推理任务(如数学逻辑链、工具调用)中精度往往不及Transformer。
  2. 硬件适配性

    • 大参数模型(如12B+)的权重与长上下文KV缓存需大量显存,普通GPU(如A10G 22GiB)难以支持128k级长序列推理,限制落地场景。
  3. 数据质量与推理能力

    • 现有预训练数据中,数学公式、多语言内容、长文档逻辑链的高质量提取难度大(如传统工具会失真LaTeX公式);
    • 推理任务需要“可控思考过程”(如指定思考token数),但现有模型缺乏针对性训练,易出现“思考超时”或“结果补偿”(用更多输出token弥补思考不足)。

2.2 研究目标

  1. 设计混合架构:融合Mamba的长序列效率与Transformer的推理精度,提升推理任务吞吐量;
  2. 构建高质量训练体系:通过curated+synthetic数据解决数学、多语言、长上下文数据质量问题;
  3. 实现高效压缩:在保证精度的前提下,将模型压缩至可在单A10G GPU上运行128k上下文;
  4. 支持可控推理:让模型能按用户指定的“思考预算”(token数)输出结果,同时保证格式正确。

三、模型架构设计(Nemotron-Nano-12B-v2-Base)

Nemotron Nano 2的基础模型是Nemotron-Nano-12B-v2-Base(12B参数),其架构继承自NVIDIA此前的Nemotron-H混合模型,核心是“少量自注意力层+大量Mamba-2层+FFN层”的组合,平衡效率与精度。

3.1 核心架构参数(表1)

架构维度具体配置
总层数62层
层类型分布自注意力层:6层(~8%)、Mamba-2层:28层、FFN层:28层
模型隐藏维度(d_model)5120
FFN隐藏维度(d_ffn)20480
注意力机制分组查询注意力(GQA):40个查询头(Q-heads)、8个键值头(KV-heads)
Mamba-2配置8个分组(groups)、状态维度(state dim)128、头维度64、膨胀因子2、卷积窗口4
激活函数平方ReLU
归一化RMSNorm,无偏置
位置嵌入无(依赖Mamba-2的时序建模能力与注意力层的上下文关联)
权重与输出层嵌入层与输出层权重分离,线性层无偏置,无dropout

3.2 层模式与结构设计

模型的层按固定规律重复,确保自注意力层均匀分散以保留关键语义关联能力,具体模式为:

  • 基础单元循环:Mamba-2层 ×3 → FFN层 ×6 → 自注意力层 ×1 → Mamba-2层 ×1 → FFN层 ×1

  • 设计逻辑:

    1. 8%的自注意力层(6层)均匀分布在62层中,负责捕捉全局语义关联(如推理链中的因果关系);
    2. 大量Mamba-2层(28层)负责长序列时序建模,降低计算复杂度;
    3. FFN层(28层)用平方ReLU激活,增强非线性表达能力,适配数学、代码等复杂领域的推理需求。

3.3 架构创新点解析

  1. 无位置嵌入:不同于传统Transformer依赖绝对/相对位置嵌入,Mamba-2通过状态空间模型(SSM)的时序建模能力处理序列顺序,同时避免位置嵌入在长序列上的泛化问题;
  2. GQA注意力:相比多头注意力(MHA),GQA用更少的KV头(8个)降低KV缓存内存占用,同时用40个Q头保留查询的精细度,平衡内存与精度;
  3. 平方ReLU激活:相比普通ReLU,平方ReLU能增强输出的平滑性与梯度传递效率,在数学计算类任务中表现更稳定(论文验证其在MATH数据集上比ReLU高2~3个百分点)。

四、预训练:20万亿tokens的高质量数据与优化训练

预训练是模型推理能力的基础,论文用20万亿tokens构建了“curated+synthetic”双轨数据体系,并通过FP8训练、长上下文扩展等优化,确保模型在多领域的基础能力。

4.1 预训练数据体系:curated数据(高质量筛选)

curated数据是从真实场景中筛选的高质量数据,覆盖通用、数学、代码、多语言四大领域,解决“数据噪声高、格式失真”问题。

4.1.1 通用网页数据:Nemotron-CC-v2

  • 来源:基于Nemotron-CC(Su et al., 2025)更新,新增8个Common Crawl快照(2024-33至2025-13)、CC-NEWS(截至2025年4月23日);

  • 处理流程:

    1. 合成重写:用Qwen3-30B-A3B替换此前的Mistral Nemo 12B,对网页内容进行重写以提升质量;
    2. 去重与过滤:全局模糊去重(MinHash LSH),CC-NEWS仅保留英文数据,无额外过滤(避免丢失知识);
  • 作用:提升模型的通用知识覆盖,更新知识截止到2025年中。

4.1.2 多语言数据:15种语言的精细筛选

  • 语言覆盖:阿拉伯语、中文、丹麦语、荷兰语、法语、德语、意大利语、日语、韩语、波兰语、葡萄牙语、俄语、西班牙语、瑞典语、泰语;
  • 数据来源:3个Common Crawl快照(2024-51、2025-08、2025-18)、多语言维基百科、FineWeb-2(Penedo et al., 2025);
  • 处理难点:缺乏可靠的多语言质量分类器,因此采用“启发式过滤”(参考英文低质量数据过滤逻辑,但禁用部分高误报率规则,如泰语的字符长度过滤);
  • 关键结论(表2:多语言数据消融实验):
    合成的DiverseQA数据(尤其是DiverseQA-crawl)效果远优于原始爬取数据,平均Global-MMLU得分达47.0,比Common Crawl(37.0)高10个百分点。
多语言数据源平均得分西班牙语德语法语日语波兰语韩语
Common Crawl(爬取)37.037.836.539.835.337.538.8
FineWeb-2(爬取)35.138.835.034.333.036.035.3
DiverseQA-wiki(合成)42.144.841.341.841.042.340.3
DiverseQA-crawl(合成)47.049.850.848.344.549.042.0

4.1.3 数学数据:Nemotron-CC-Math-v1(133B tokens)

核心痛点:传统数学数据提取工具(如OpenWebMath、MegaMath)会丢失LaTeX公式、MathML结构,导致数学推理能力下降。
解决方案:构建全新高保真提取 pipeline:

  1. URL聚合:从InfiMM-WebMath、OpenWebMath等6个现有数学数据集收集数学相关URL;
  2. 原始HTML重爬:从98个Common Crawl快照(2014-2024)中重新获取原始HTML,避免中间工具失真;
  3. 结构保留渲染:用lynx文本浏览器渲染页面,完整保留公式布局与代码格式;
  4. 标准化处理:用Phi-4(14B参数模型)移除无关内容(如广告)、将公式统一为LaTeX格式、修正符号不一致;
  5. 质量筛选:用FineMath分类器(Allal et al., 2025)保留高质量文档,MinHash LSH去重,LLM Decontaminator(Yang et al., 2023)去除基准污染数据;
  • 最终成果:133B token的Nemotron-CC-Math-3+,以及52B token的高质量子集Nemotron-CC-Math-4+(仅保留Top得分样本);
  • 效果:在MATH-500、HumanEval+、MMLU-Pro等基准上超越所有现有开源数学数据集,MATH Level 5(最难数学题)得分提升15%+。

4.1.4 代码数据:GitHub高质量筛选

  • 来源:GitHub开源代码,覆盖11种编程语言(Python、Java、C++、Go等);

  • 处理流程:

    1. 许可证过滤:仅保留允许商用的许可证(如Apache-2.0、MIT,完整列表见论文附录A),排除GPL等强copyleft许可证;
    2. 去重:先精确去重(文件哈希),再模糊去重(MinHash LSH,避免相似代码重复训练);
    3. 质量过滤:采用OpenCoder(Huang et al., 2025)的启发式规则,过滤空文件、注释占比过高(>80%)、语法错误的代码;
  • 作用:提升模型的代码生成与调试能力,在HumanEval+、MBPP+基准上表现优于Qwen3-8B。

4.2 预训练数据体系:synthetic数据(针对性生成)

synthetic数据是通过大模型生成的“高适配性数据”,弥补curated数据在推理、多语言、长上下文场景的不足,共生成约200B tokens。

4.2.1 STEM数据:科学与数学推理增强

  • 种子数据:88.6k个STEM问题,来自GSM8K、MATH、AOPS(数学竞赛)、Stemez、OpenStax开源教材;

  • 生成策略:用Qwen2.5-VL-72B-Instruct提取教材习题(忽略需图片的题,公式转为LaTeX),再用4个模型(Qwen3-30B-A3B、Qwen3-235B-A22B、Deepseek-R1、Deepseek V3)生成三类新题:

    1. 相似题:考察相同概念,但数值/场景不同;
    2. 更难题:增加逻辑步骤(如从“单变量方程”到“多变量方程组”);
    3. 不同类型题:如从“选择题”转为“计算题”;
  • 多语言扩展:将GSM8K的部分数据翻译为15种目标语言,并用目标语言添加结论句(如西班牙语“La respuesta es ...”),确保数学符号与语言一致。

4.2.2 数学MIND数据集:结构化对话增强

  • 背景:原始MIND数据集(Akter et al., 2024)基于低质量OpenWebMath(14.7B tokens)生成,推理能力有限;
  • 改进方案:用高质量的Nemotron-CC-Math-4+(52B tokens)作为源数据,用Phi-4生成7类结构化数学对话(教师-学生、辩论、访谈等);
  • 成果:73B token的新MIND数据集,在MMLU-STEM、MATH基准上比原始MIND提升8~10个百分点,验证“输入数据质量决定生成效果”。

4.2.3 基础推理SFT数据:逻辑与阅读理解增强

核心目标:解决模型在“多干扰项选择”中的推理能力不足(如MMLU-Pro需从10个选项中选正确答案)。

  • 种子数据集:

    1. LSAT(美国法学院入学考试):含逻辑推理、阅读理解、分析推理三类题;
    2. LogiQA:来自中国公务员考试的逻辑题;
    3. AQuA-RAT:代数文字题(Ling et al., 2017);
  • 生成流程:

    1. 用DeepSeek-V3和Qwen3-30B-A3B生成相似题,要求模型“避免表面修改,需推导解题步骤”;
    2. 对每个生成题,用DeepSeek-V3生成完整思考链(CoT);
    3. 多数投票筛选:仅保留“多模型生成结果一致”的样本,确保正确性;
  • 成果:8.2B tokens数据,在MMLU-Pro上提升12.12个百分点(表3),验证其对逻辑推理的增强作用。

模型平均数学得分平均代码得分平均常识推理得分MMLUMMLU-Pro
Nemotron-H 8B(无FR-SFT)37.9259.4971.7972.6744.24
Nemotron-H 8B(有FR-SFT)39.7059.6171.4372.9856.36

4.2.4 其他synthetic数据

  • 多语言Diverse QA:从多语言维基生成QA对,或翻译英文Diverse QA,确保15种语言的问答质量;
  • 学术QA:从本科/研究生级技术文档(数学、化学、医学)中提取512token片段,用e5-large嵌入后存入Milvus,再生成多 choice/自由回答QA(带答案解释);
  • SFT风格数据:覆盖代码SFT(解题)、数学SFT(推理)、MMLU-SFT(知识)、通用指令SFT,提前让模型适应对齐阶段的任务格式。

4.3 数据混合与训练阶段

为避免模型偏科,论文采用“三阶段 curriculum 训练”,按“多样性→高质量→长上下文”逐步优化:
[插入图3]

阶段1(Phase 1):多样性优先(训练前60% tokens)

  • 目标:让模型接触广泛领域,避免早期过拟合;
  • 数据占比:crawl-medium(18.3%)、crawl-medium-high(20.0%)、code(14.8%)、syn-crawl-high(16.2%)、crawl-high(11.1%)、multilingual(5.0%)、academic(4.4%)、math(3.2%)、stem-sft(3.1%)。

阶段2(Phase 2):高质量优先(训练60%~90% tokens)

  • 目标:强化核心领域(数学、代码、长文本)能力;
  • 数据占比:syn-crawl-high(21.0%)、code(20.0%)、crawl-high(16.0%)、math(9.5%)、multilingual(5.0%)、stem-sft(14.5%)、code-sft(4.4%)、crawl++(4.4%)、academic(3.8%)、wiki(0.9%)。

阶段3(Phase 3):推理与长文本优先(训练最后10% tokens)

  • 目标:为后续长上下文扩展做准备;
  • 数据占比:stem-sft(32.0%)、code(16.0%)、syn-crawl-high(12.7%)、math(11.0%)、multilingual(4.4%)、crawl-high(10.9%)、code-sft(10.0%)。

4.4 预训练优化:FP8训练与超参数

4.4.1 FP8训练方案(DeepSeek改进版)

  • 核心需求:20万亿tokens训练量需兼顾速度与精度,FP8比BF16内存占用减少50%,训练速度提升30%+;

  • 具体配置:

    1. 数据格式:所有张量用E4M3(4位指数+3位尾数),平衡精度与范围;
    2. 量化块:权重用128x128块量化,激活用1x128 tile量化,减少量化误差;
    3. 精度保留:首尾4个矩阵乘法用BF16(避免输入/输出层精度损失),优化器状态用FP32(保证更新稳定性);
  • 效果:无训练不稳定问题,最终模型精度比BF16训练仅低0.5%以内。

4.4.2 关键超参数

超参数配置
训练token总量20万亿
序列长度8192(Phase 1-3)、524288(512k,Phase LC长上下文扩展)
全局batch size768(6,029,312 tokens/批,无ramp-up)
学习率调度WSD(Warmup-Stable-Decay):稳定期4.5e-4,最低4.5e-6,最后3.6T token衰减
权重衰减0.1
Adam优化器β₁=0.9,β₂=0.95
并行策略8路张量并行(TP)+ 16路数据并行(DP)

4.5 长上下文扩展(Phase LC)

核心目标:让模型支持128k推理,但论文选择用512k序列长度训练,理由是“更长训练序列可减少长文档被切割,提升上下文连贯性”。

  • 训练细节:

    1. 序列长度:524,288(512k);
    2. 并行策略:8路张量并行+16路上下文并行(将长序列拆分到不同GPU处理,避免单卡内存溢出);
    3. 数据:在Phase 3基础上,将20%数据替换为“长文档QA”(从学术文档中拆分1024token片段,生成QA后拼接回文档);
    4. 训练量:18.9B tokens;
  • 效果(表4):512k训练+合成数据的RULER-128k得分达81.04,比256k训练(无合成数据)高10.85个百分点。

训练序列长度是否用合成数据RULER-128k得分
128k73.68
256k70.19
256k79.04
512k81.04

4.6 基础模型评估(表5、表6)

Nemotron-Nano-12B-v2-Base(12B)与9B剪枝版(N-Nano-V2 9B Base)在多基准上超越同规模模型(Qwen3-8B、Gemma3-12B):

通用与数学能力(表5)

任务N-Nano-V2 12B BaseN-Nano-V2 9B BaseQwen3 8B BaseGemma3 12B Base
MMLU(通用知识)78.2474.5376.4473.61
MMLU-Pro(难知识)63.9859.4356.2745.12
GSM8K CoT(数学)91.6691.3684.0074.45
MATH(数学)83.5480.5055.4042.40
MATH Level 5(难题)67.6163.6429.9117.71
AIME 2024 pass@3256.6730.0020.0016.67

多语言能力(表6)

任务语言N-Nano-V2 12B BaseN-Nano-V2 9B BaseQwen3 8B BaseGemma3 12B Base
Global-MMLU-Lite(平均)-75.1369.9472.8171.88
MGSM(多语言数学)西班牙语93.2091.6086.4074.00
MGSM中文44.4075.2028.8026.80
MGSM(平均)-80.0084.8064.5357.13

五、对齐:从基础模型到“可控推理模型”

对齐阶段的目标是将12B基础模型转化为“能遵循指令、支持工具调用、可控思考”的实用模型,流程为:Base → SFT1 → SFT2 → SFT3 → DPO → RLHF → GRPO → 模型融合,共训练约90B tokens。

image.png

5.1 对齐数据体系(表7)

对齐数据以“单轮prompt-响应”为主,含推理链,覆盖6大领域,确保模型适配不同任务场景。

数据领域样本数量核心来源与处理
数学1.5M复用预训练数学数据,用DeepSeek-R1-0528生成思考链响应
代码1.1MGitHub代码+合成QA,响应含代码解释与调试步骤
科学2.0MSTEM数据集+学术论文,响应含科学原理推导
工具调用400Kxlam-function-calling-60k、When2Call等,生成单轮/多轮/多步工具调用对话
对话1.5MLMSYS、HelpSteer2/3、WildChat1M,用Qwen3-235B生成自然对话响应
安全5.0MNemotron Content Safety V2、HarmfulTasks,用DeepSeek-R1生成安全响应+guard模型过滤
多语言2K翻译上述数据至5种语言(西、法、德、意、日)

5.1.1 工具调用数据:多场景模拟

工具调用是推理任务的核心能力,论文设计“三角色模拟”生成高质量数据:

  • User-Agent:模拟用户,审查工具列表、提出复杂查询、判断任务成功与否;
  • Assistant-Agent:模拟模型,根据查询调用工具、解析结果、与用户交互;
  • API-Server-Agent:模拟API服务器,检查参数合法性,返回正确结果或错误信息;
  • 质量控制:用规则验证层确保工具调用格式正确(如参数类型匹配),仅保留成功轨迹。

5.1.2 预算控制数据:截断推理链

为实现“指定思考token数”功能,论文在SFT3阶段加入5%“截断推理链数据”:

  • 生成逻辑:保留完整问题与最终答案,但将中间思考链截断至1~2k tokens;
  • 作用:让模型学习“在有限思考步骤内收敛到正确答案”,避免推理超时。

5.2 多阶段对齐训练

5.2.1 SFT阶段:分三步优化核心能力

  1. SFT1:全领域覆盖

    • 数据:完整对齐数据集(含10%“无推理链数据”,让模型支持“直接回答”模式);
    • 优化:将样本拼接为128k tokens序列,减少padding,强化长上下文学习;
    • 目标:让模型适应指令格式,覆盖数学、代码、科学等基础领域。
  2. SFT2:工具调用修复

    • 问题:SFT1的128k拼接导致工具调用格式错乱(如参数缺失);
    • 方案:不拼接样本,用完整工具调用数据集+其他领域子集训练;
    • 目标:修复工具调用精度,BFCL v3(工具调用基准)得分从55%提升至66%。
  3. SFT3:长上下文与预算控制

    • 数据:加入长上下文文档QA(128k序列)+截断推理链数据;
    • 目标:强化128k上下文理解,让模型适应“思考预算”约束。

5.2.2 IFeval RL:指令遵循精度提升

  • 数据:从LMSYS Chat数据集采样16k prompt,添加IFeval风格指令(如“用3步解释答案”);
  • 奖励信号:用规则验证器评分(如“是否分3步”“是否包含关键术语”);
  • 效果:IFeval(严格指令遵循)得分从85%提升至89.81,接近Qwen3-14B(91.32)。

5.2.3 DPO:工具调用与多步推理增强

  • 基准:BFCL v3(强调多步工具调用,如“先查天气再订机票”);

  • 环境:WorkBench(多步可验证工具调用环境,通过数据库状态对比判断正确性);

  • 流程:

    1. 对SFT3 checkpoint,生成每个WorkBench prompt的“成功轨迹”(正例)与“失败轨迹”(负例);
    2. 用DPO优化模型,让模型偏好正例;
  • 效果:BFCL v3得分从66%提升至66.98,与Qwen3-8B(66.34)相当。

5.2.4 RLHF与GRPO:对话与推理平衡

  • 基准:Arena-Hard(评估对话自然度与帮助性);

  • 数据:HelpSteer3英文上下文,生成“带思考链”与“无思考链”两种响应;

  • 奖励模型:Qwen-based模型,评分维度包括“正确性”“自然度”“帮助性”;

  • 算法:用GRPO替代传统PPO,减少训练波动;

  • 问题:RLHF提升对话能力,但数学推理得分下降3%;

  • 解决方案:模型融合(checkpoint插值),公式为:

    实验发现α=0.5时,推理与对话能力平衡最佳(ArenaHard得分74,接近Qwen3-8B的78.4)。

5.3 对齐后评估(表8)

Nemotron-Nano-v2-12B(对齐后的12B模型)在推理基准上超越Qwen3-8B,部分接近Qwen3-14B:

评估任务Nemotron-Nano-v2-12BQwen3-8BQwen3-14B
AIME-2024(数学竞赛)85.4275.8381.53
AIME-2025(数学竞赛)76.2569.3166.60
MATH-500(数学)97.7596.3096.85
LiveCodeBench(代码)70.7959.5063.08
RULER @ 128k(长上下文)83.3674.1373.55
BFCL v3(工具调用)66.9866.3468.01

5.4 预算控制评估

论文设计“思考token数限制”功能,用户可指定模型生成多少token的思考链后必须输出答案,核心效果如下:

  1. 无补偿token:未训练截断数据的模型会“用更多输出token补偿思考不足”(如思考被限制1k token,输出用3k token写答案),而训练后模型无此问题;
  2. 格式正确:未训练模型在预算耗尽时,会重复生成闭合标签(如多次输出</think>),训练后“格式正确率”保持95%以上(仅一个闭合标签);
  3. 精度稳定:即使思考token从8k减少到2k,MATH-500得分仅下降2%(从97.75到95.5),验证预算控制的实用性。

六、剪枝与蒸馏:从12B到9B,适配单A10G GPU

基础模型(12B)的bfloat16权重需22.9GiB内存,超过A10G(22GiB)的容量,且128k上下文的KV缓存会进一步占用内存。论文基于Minitron策略(Muralidharan et al., 2024)进行压缩,核心是“剪枝+蒸馏”,在保证精度的前提下减少参数与内存占用。

6.1 压缩目标与约束

  • 内存约束:单A10G GPU(22GiB),bfloat16精度,支持128k序列+batch=1,预留5%框架缓冲(1.1GiB)+1.3GiB视觉编码器内存,实际预算19.66GiB;
  • 吞吐量约束:8k输入+16k输出场景下,vLLM吞吐量不低于150 token/s/GPU;
  • 精度约束:推理基准平均得分不低于原始12B模型的95%。

6.2 重要性估计:决定“剪哪些部分”

剪枝的核心是“移除对精度影响最小的组件”,论文对层、FFN神经元、Mamba头、嵌入通道分别计算重要性:

6.2.1 层重要性:MSE迭代评估

  • 流程:对每个层,临时移除后计算“原始模型logits与剪枝模型logits的MSE”,MSE越低说明该层越不重要;
  • 策略:迭代移除MSE最低的层,直到达到目标深度;
  • 结果:原始62层,剪至56层时精度损失最小(表9),若继续剪至52层,平均推理得分从51.48降至44.92(下降12.7%)。
层数平均推理精度
5244.92
5447.35
5651.48

6.2.2 FFN与嵌入通道重要性:激活值聚合

FFN层公式为:为平方ReLU),神经元重要性通过激活值聚合计算:

  • 校准数据:1024个样本;

  • 得分公式:对的第i行(对应第i个神经元),计算激活值的均值或L2范数:

    (B=batch维度,S=序列维度);

  • 嵌入通道重要性:类似FFN,通过LayerNorm输出的激活值聚合计算。

6.2.3 Mamba头重要性:组内排序

Mamba-2层含多个分组,每个分组含多个头,剪枝需保留组内结构:

  • 通道得分:对 投影的激活值,计算每个通道的 L2 范数:

  • 头得分:对每个头,计算其所有通道得分的 L2 范数:

  • 组内排序:在每个 Mamba 组内对头部排序,移除得分最低的头;

  • 结论:剪 Mamba 头的精度损失比剪 FFN 大,因此优先剪 FFN(最终保留所有 Mamba 头)。

6.3 轻量级NAS:选择最优架构

通过“枚举候选→筛选最优”的方式,找到满足内存与精度约束的架构:

6.3.1 候选枚举

搜索空间包括:

  • 深度:52~56层;
  • 嵌入通道:4480~5120;
  • FFN维度:13440~20480;
  • Mamba头:112~128;
  • 共生成数百个候选,仅保留内存≤19.66GiB的候选。

6.3.2 最优架构选择(表10)

固定深度为56层(精度最高),对宽度剪枝后的候选进行19B token蒸馏,评估精度与吞吐量:

候选层数嵌入通道FFN维度Mamba头参数(B)平均精度吞吐量(token/s)
候选1564480179201128.9259.07161.02
候选2564480156801288.8963.02156.42
候选3564800144001208.9762.94155.86
  • 选择候选2:精度最高(63.02),吞吐量满足要求(156.42),参数8.89B(四舍五入为9B,即Nemotron-Nano-9B-v2)。

6.4 蒸馏再训练:恢复剪枝精度损失

剪枝会导致精度下降,论文通过“知识蒸馏(KD)”从12B教师模型向9B学生模型传递知识,分阶段进行:

6.4.1 推理模型蒸馏(对齐后模型)

  1. 深度剪枝:剪至56层,用60B token KD(8k序列);
  2. 宽度剪枝:嵌入通道4480,FFN 15680,用50B token KD(8k)+25B token KD(49k)+1B token KD(262k);
  3. 对齐恢复:DPO→GRPO→0.4B token KD(262k)→RLHF→模型融合;
  • 数据集:70% SFT2数据+30%预训练数据(表11,该比例精度最高)。
推理SFT数据占比预训练数据占比平均数学精度
50%50%57.5
70%30%58.5
90%10%57.2

6.4.2 基础模型蒸馏(未对齐模型)

  1. 深度剪枝:剪至56层,用120B token KD(8k);
  2. 宽度剪枝:同候选2,用360B token KD(8k);
  3. 长上下文增强:用2.5B token KD(524k);
  • 数据集:100%预训练数据,确保基础能力。

6.5 压缩效果验证

  • 内存占用:9B模型bfloat16权重+128k KV缓存(batch=1)共占用20.8GiB,低于A10G的22GiB;
  • 吞吐量:8k输入+16k输出场景下,吞吐量达156.42 token/s,是Qwen3-8B(25 token/s)的6.26×;
  • 精度:在MATH-500、RULER-128k等基准上,仅比12B模型低3%以内,优于Qwen3-8B(表5、表8)。

七、结论与开源资源

7.1 核心贡献总结

  1. 架构创新:混合Mamba-Transformer架构,实现推理任务“高精度+高吞吐量”的平衡;
  2. 数据体系:构建curated+synthetic双轨数据,解决数学、多语言、长上下文数据质量问题;
  3. 压缩策略:基于Minitron的剪枝+蒸馏,让9B模型在单A10G上支持128k推理;
  4. 可控推理:通过截断训练实现“思考预算控制”,支持工具调用与多语言;
  5. 开源生态:开放模型与数据集,降低社区研究门槛。

7.2 开源资源(Hugging Face)

  1. 模型checkpoint

    • NVIDIA-Nemotron-Nano-9B-v2:对齐+剪枝的推理模型;
    • NVIDIA-Nemotron-Nano-9B-v2-Base:剪枝后的基础模型;
    • NVIDIA-Nemotron-Nano-12B-v2-Base:12B预训练基础模型;
  2. 数据集

    • Nemotron-PreTraining-Dataset-v1:6万亿token预训练数据(含Nemotron-CC-v2、Nemotron-CC-Math-v1等);
    • Nemotron-Post-Training-Dataset-v2:对齐数据集(支持5种语言,链接待更新)。

八、关键术语解释(便于理解)

  1. Mamba-2:基于状态空间模型(SSM)的序列建模层,计算复杂度O(n),长序列效率高于Transformer;
  2. GQA(分组查询注意力) :将查询头分组,共享键值头,平衡内存(KV缓存减少)与精度;
  3. FP8/BF16:浮点数格式,FP8(8位)比BF16(16位)内存占用少50%,适合大模型训练;
  4. KV缓存:Transformer推理时缓存key和value,避免重复计算,但长序列时占用大量内存;
  5. 知识蒸馏(KD) :用大模型(教师)指导小模型(学生)训练,保留精度的同时减少参数;
  6. DPO(直接偏好优化) :通过“偏好数据”(正例vs负例)优化模型,无需显式奖励模型;
  7. GRPO(组相对策略优化) :RLHF的改进算法,减少训练波动,适合推理模型;
  8. RULER-128k:长上下文基准,评估模型在128k序列上的语义理解能力。