论文标题:NVIDIA Nemotron Nano 2: An Accurate and Efficient Hybrid Mamba-Transformer Reasoning Model
发布机构:NVIDIA
论文地址:arxiv.org/abs/2508.14…
一、核心亮点/重要结论
-
架构创新:采用混合Mamba-Transformer架构,替换Transformer中大部分自注意力层为Mamba-2层,在长推理链生成场景下(如8k输入+16k输出)实现推理速度跃升,同时保持精度。
-
性能突破:最终的9B参数模型(Nemotron-Nano-9B-v2)与同规模模型(如Qwen3-8B)相比:
- 精度:在数学(AIME24/25)、代码(LiveCodeBench)、长上下文(RULER-128k)等推理基准上相当或更优;
- 吞吐量:生成密集场景(1k输入/8k输出、8k输入/16k输出)下实现3×~6×更高推理吞吐量(单NVIDIA A10G GPU,bfloat16精度)。
-
硬件友好性:通过剪枝与蒸馏,实现单A10G GPU(22GiB内存)上支持128k tokens长上下文推理(bfloat16精度),突破小显存硬件的长序列处理限制。
-
数据与训练优化:
- 预训练:基于20万亿tokens,结合高质量 curated数据(网页、数学、代码)与synthetic数据(STEM、多语言QA、推理题),提升多领域推理能力;
- 对齐:通过多阶段SFT、GRPO、DPO、RLHF,平衡工具调用、长上下文与对话能力;
- 预算控制:支持指定“思考token数”,模型可在限制思考步骤内输出正确结果,且避免格式错乱。
-
开源贡献:在Hugging Face开源3个模型 checkpoint(9B对齐版、9B基础版、12B基础版)及大部分预训练/后训练数据集,降低社区研究门槛。
二、研究背景与问题提出
2.1 大模型推理任务的核心挑战
当前大语言模型(LLM)在复杂推理场景(如数学证明、代码生成、长文档分析)中面临两大核心矛盾:
-
精度与效率的平衡:
- Transformer架构的自注意力机制(计算复杂度O(n²))在长序列(如16k+ tokens)推理时,计算量与KV缓存内存占用急剧增加,导致吞吐量骤降;
- 纯Mamba模型虽在长序列上效率更高(O(n)复杂度),但在需要精细语义关联的推理任务(如数学逻辑链、工具调用)中精度往往不及Transformer。
-
硬件适配性:
- 大参数模型(如12B+)的权重与长上下文KV缓存需大量显存,普通GPU(如A10G 22GiB)难以支持128k级长序列推理,限制落地场景。
-
数据质量与推理能力:
- 现有预训练数据中,数学公式、多语言内容、长文档逻辑链的高质量提取难度大(如传统工具会失真LaTeX公式);
- 推理任务需要“可控思考过程”(如指定思考token数),但现有模型缺乏针对性训练,易出现“思考超时”或“结果补偿”(用更多输出token弥补思考不足)。
2.2 研究目标
- 设计混合架构:融合Mamba的长序列效率与Transformer的推理精度,提升推理任务吞吐量;
- 构建高质量训练体系:通过curated+synthetic数据解决数学、多语言、长上下文数据质量问题;
- 实现高效压缩:在保证精度的前提下,将模型压缩至可在单A10G GPU上运行128k上下文;
- 支持可控推理:让模型能按用户指定的“思考预算”(token数)输出结果,同时保证格式正确。
三、模型架构设计(Nemotron-Nano-12B-v2-Base)
Nemotron Nano 2的基础模型是Nemotron-Nano-12B-v2-Base(12B参数),其架构继承自NVIDIA此前的Nemotron-H混合模型,核心是“少量自注意力层+大量Mamba-2层+FFN层”的组合,平衡效率与精度。
3.1 核心架构参数(表1)
| 架构维度 | 具体配置 |
|---|---|
| 总层数 | 62层 |
| 层类型分布 | 自注意力层:6层(~8%)、Mamba-2层:28层、FFN层:28层 |
| 模型隐藏维度(d_model) | 5120 |
| FFN隐藏维度(d_ffn) | 20480 |
| 注意力机制 | 分组查询注意力(GQA):40个查询头(Q-heads)、8个键值头(KV-heads) |
| Mamba-2配置 | 8个分组(groups)、状态维度(state dim)128、头维度64、膨胀因子2、卷积窗口4 |
| 激活函数 | 平方ReLU |
| 归一化 | RMSNorm,无偏置 |
| 位置嵌入 | 无(依赖Mamba-2的时序建模能力与注意力层的上下文关联) |
| 权重与输出层 | 嵌入层与输出层权重分离,线性层无偏置,无dropout |
3.2 层模式与结构设计
模型的层按固定规律重复,确保自注意力层均匀分散以保留关键语义关联能力,具体模式为:
-
基础单元循环:
Mamba-2层 ×3 → FFN层 ×6 → 自注意力层 ×1 → Mamba-2层 ×1 → FFN层 ×1 -
设计逻辑:
- 8%的自注意力层(6层)均匀分布在62层中,负责捕捉全局语义关联(如推理链中的因果关系);
- 大量Mamba-2层(28层)负责长序列时序建模,降低计算复杂度;
- FFN层(28层)用平方ReLU激活,增强非线性表达能力,适配数学、代码等复杂领域的推理需求。
3.3 架构创新点解析
- 无位置嵌入:不同于传统Transformer依赖绝对/相对位置嵌入,Mamba-2通过状态空间模型(SSM)的时序建模能力处理序列顺序,同时避免位置嵌入在长序列上的泛化问题;
- GQA注意力:相比多头注意力(MHA),GQA用更少的KV头(8个)降低KV缓存内存占用,同时用40个Q头保留查询的精细度,平衡内存与精度;
- 平方ReLU激活:相比普通ReLU,平方ReLU能增强输出的平滑性与梯度传递效率,在数学计算类任务中表现更稳定(论文验证其在MATH数据集上比ReLU高2~3个百分点)。
四、预训练:20万亿tokens的高质量数据与优化训练
预训练是模型推理能力的基础,论文用20万亿tokens构建了“curated+synthetic”双轨数据体系,并通过FP8训练、长上下文扩展等优化,确保模型在多领域的基础能力。
4.1 预训练数据体系:curated数据(高质量筛选)
curated数据是从真实场景中筛选的高质量数据,覆盖通用、数学、代码、多语言四大领域,解决“数据噪声高、格式失真”问题。
4.1.1 通用网页数据:Nemotron-CC-v2
-
来源:基于Nemotron-CC(Su et al., 2025)更新,新增8个Common Crawl快照(2024-33至2025-13)、CC-NEWS(截至2025年4月23日);
-
处理流程:
- 合成重写:用Qwen3-30B-A3B替换此前的Mistral Nemo 12B,对网页内容进行重写以提升质量;
- 去重与过滤:全局模糊去重(MinHash LSH),CC-NEWS仅保留英文数据,无额外过滤(避免丢失知识);
-
作用:提升模型的通用知识覆盖,更新知识截止到2025年中。
4.1.2 多语言数据:15种语言的精细筛选
- 语言覆盖:阿拉伯语、中文、丹麦语、荷兰语、法语、德语、意大利语、日语、韩语、波兰语、葡萄牙语、俄语、西班牙语、瑞典语、泰语;
- 数据来源:3个Common Crawl快照(2024-51、2025-08、2025-18)、多语言维基百科、FineWeb-2(Penedo et al., 2025);
- 处理难点:缺乏可靠的多语言质量分类器,因此采用“启发式过滤”(参考英文低质量数据过滤逻辑,但禁用部分高误报率规则,如泰语的字符长度过滤);
- 关键结论(表2:多语言数据消融实验):
合成的DiverseQA数据(尤其是DiverseQA-crawl)效果远优于原始爬取数据,平均Global-MMLU得分达47.0,比Common Crawl(37.0)高10个百分点。
| 多语言数据源 | 平均得分 | 西班牙语 | 德语 | 法语 | 日语 | 波兰语 | 韩语 |
|---|---|---|---|---|---|---|---|
| Common Crawl(爬取) | 37.0 | 37.8 | 36.5 | 39.8 | 35.3 | 37.5 | 38.8 |
| FineWeb-2(爬取) | 35.1 | 38.8 | 35.0 | 34.3 | 33.0 | 36.0 | 35.3 |
| DiverseQA-wiki(合成) | 42.1 | 44.8 | 41.3 | 41.8 | 41.0 | 42.3 | 40.3 |
| DiverseQA-crawl(合成) | 47.0 | 49.8 | 50.8 | 48.3 | 44.5 | 49.0 | 42.0 |
4.1.3 数学数据:Nemotron-CC-Math-v1(133B tokens)
核心痛点:传统数学数据提取工具(如OpenWebMath、MegaMath)会丢失LaTeX公式、MathML结构,导致数学推理能力下降。
解决方案:构建全新高保真提取 pipeline:
- URL聚合:从InfiMM-WebMath、OpenWebMath等6个现有数学数据集收集数学相关URL;
- 原始HTML重爬:从98个Common Crawl快照(2014-2024)中重新获取原始HTML,避免中间工具失真;
- 结构保留渲染:用lynx文本浏览器渲染页面,完整保留公式布局与代码格式;
- 标准化处理:用Phi-4(14B参数模型)移除无关内容(如广告)、将公式统一为LaTeX格式、修正符号不一致;
- 质量筛选:用FineMath分类器(Allal et al., 2025)保留高质量文档,MinHash LSH去重,LLM Decontaminator(Yang et al., 2023)去除基准污染数据;
- 最终成果:133B token的Nemotron-CC-Math-3+,以及52B token的高质量子集Nemotron-CC-Math-4+(仅保留Top得分样本);
- 效果:在MATH-500、HumanEval+、MMLU-Pro等基准上超越所有现有开源数学数据集,MATH Level 5(最难数学题)得分提升15%+。
4.1.4 代码数据:GitHub高质量筛选
-
来源:GitHub开源代码,覆盖11种编程语言(Python、Java、C++、Go等);
-
处理流程:
- 许可证过滤:仅保留允许商用的许可证(如Apache-2.0、MIT,完整列表见论文附录A),排除GPL等强copyleft许可证;
- 去重:先精确去重(文件哈希),再模糊去重(MinHash LSH,避免相似代码重复训练);
- 质量过滤:采用OpenCoder(Huang et al., 2025)的启发式规则,过滤空文件、注释占比过高(>80%)、语法错误的代码;
-
作用:提升模型的代码生成与调试能力,在HumanEval+、MBPP+基准上表现优于Qwen3-8B。
4.2 预训练数据体系:synthetic数据(针对性生成)
synthetic数据是通过大模型生成的“高适配性数据”,弥补curated数据在推理、多语言、长上下文场景的不足,共生成约200B tokens。
4.2.1 STEM数据:科学与数学推理增强
-
种子数据:88.6k个STEM问题,来自GSM8K、MATH、AOPS(数学竞赛)、Stemez、OpenStax开源教材;
-
生成策略:用Qwen2.5-VL-72B-Instruct提取教材习题(忽略需图片的题,公式转为LaTeX),再用4个模型(Qwen3-30B-A3B、Qwen3-235B-A22B、Deepseek-R1、Deepseek V3)生成三类新题:
- 相似题:考察相同概念,但数值/场景不同;
- 更难题:增加逻辑步骤(如从“单变量方程”到“多变量方程组”);
- 不同类型题:如从“选择题”转为“计算题”;
-
多语言扩展:将GSM8K的部分数据翻译为15种目标语言,并用目标语言添加结论句(如西班牙语“La respuesta es ...”),确保数学符号与语言一致。
4.2.2 数学MIND数据集:结构化对话增强
- 背景:原始MIND数据集(Akter et al., 2024)基于低质量OpenWebMath(14.7B tokens)生成,推理能力有限;
- 改进方案:用高质量的Nemotron-CC-Math-4+(52B tokens)作为源数据,用Phi-4生成7类结构化数学对话(教师-学生、辩论、访谈等);
- 成果:73B token的新MIND数据集,在MMLU-STEM、MATH基准上比原始MIND提升8~10个百分点,验证“输入数据质量决定生成效果”。
4.2.3 基础推理SFT数据:逻辑与阅读理解增强
核心目标:解决模型在“多干扰项选择”中的推理能力不足(如MMLU-Pro需从10个选项中选正确答案)。
-
种子数据集:
- LSAT(美国法学院入学考试):含逻辑推理、阅读理解、分析推理三类题;
- LogiQA:来自中国公务员考试的逻辑题;
- AQuA-RAT:代数文字题(Ling et al., 2017);
-
生成流程:
- 用DeepSeek-V3和Qwen3-30B-A3B生成相似题,要求模型“避免表面修改,需推导解题步骤”;
- 对每个生成题,用DeepSeek-V3生成完整思考链(CoT);
- 多数投票筛选:仅保留“多模型生成结果一致”的样本,确保正确性;
-
成果:8.2B tokens数据,在MMLU-Pro上提升12.12个百分点(表3),验证其对逻辑推理的增强作用。
| 模型 | 平均数学得分 | 平均代码得分 | 平均常识推理得分 | MMLU | MMLU-Pro |
|---|---|---|---|---|---|
| Nemotron-H 8B(无FR-SFT) | 37.92 | 59.49 | 71.79 | 72.67 | 44.24 |
| Nemotron-H 8B(有FR-SFT) | 39.70 | 59.61 | 71.43 | 72.98 | 56.36 |
4.2.4 其他synthetic数据
- 多语言Diverse QA:从多语言维基生成QA对,或翻译英文Diverse QA,确保15种语言的问答质量;
- 学术QA:从本科/研究生级技术文档(数学、化学、医学)中提取512token片段,用e5-large嵌入后存入Milvus,再生成多 choice/自由回答QA(带答案解释);
- SFT风格数据:覆盖代码SFT(解题)、数学SFT(推理)、MMLU-SFT(知识)、通用指令SFT,提前让模型适应对齐阶段的任务格式。
4.3 数据混合与训练阶段
为避免模型偏科,论文采用“三阶段 curriculum 训练”,按“多样性→高质量→长上下文”逐步优化:
[插入图3]
阶段1(Phase 1):多样性优先(训练前60% tokens)
- 目标:让模型接触广泛领域,避免早期过拟合;
- 数据占比:crawl-medium(18.3%)、crawl-medium-high(20.0%)、code(14.8%)、syn-crawl-high(16.2%)、crawl-high(11.1%)、multilingual(5.0%)、academic(4.4%)、math(3.2%)、stem-sft(3.1%)。
阶段2(Phase 2):高质量优先(训练60%~90% tokens)
- 目标:强化核心领域(数学、代码、长文本)能力;
- 数据占比:syn-crawl-high(21.0%)、code(20.0%)、crawl-high(16.0%)、math(9.5%)、multilingual(5.0%)、stem-sft(14.5%)、code-sft(4.4%)、crawl++(4.4%)、academic(3.8%)、wiki(0.9%)。
阶段3(Phase 3):推理与长文本优先(训练最后10% tokens)
- 目标:为后续长上下文扩展做准备;
- 数据占比:stem-sft(32.0%)、code(16.0%)、syn-crawl-high(12.7%)、math(11.0%)、multilingual(4.4%)、crawl-high(10.9%)、code-sft(10.0%)。
4.4 预训练优化:FP8训练与超参数
4.4.1 FP8训练方案(DeepSeek改进版)
-
核心需求:20万亿tokens训练量需兼顾速度与精度,FP8比BF16内存占用减少50%,训练速度提升30%+;
-
具体配置:
- 数据格式:所有张量用E4M3(4位指数+3位尾数),平衡精度与范围;
- 量化块:权重用128x128块量化,激活用1x128 tile量化,减少量化误差;
- 精度保留:首尾4个矩阵乘法用BF16(避免输入/输出层精度损失),优化器状态用FP32(保证更新稳定性);
-
效果:无训练不稳定问题,最终模型精度比BF16训练仅低0.5%以内。
4.4.2 关键超参数
| 超参数 | 配置 |
|---|---|
| 训练token总量 | 20万亿 |
| 序列长度 | 8192(Phase 1-3)、524288(512k,Phase LC长上下文扩展) |
| 全局batch size | 768(6,029,312 tokens/批,无ramp-up) |
| 学习率调度 | WSD(Warmup-Stable-Decay):稳定期4.5e-4,最低4.5e-6,最后3.6T token衰减 |
| 权重衰减 | 0.1 |
| Adam优化器 | β₁=0.9,β₂=0.95 |
| 并行策略 | 8路张量并行(TP)+ 16路数据并行(DP) |
4.5 长上下文扩展(Phase LC)
核心目标:让模型支持128k推理,但论文选择用512k序列长度训练,理由是“更长训练序列可减少长文档被切割,提升上下文连贯性”。
-
训练细节:
- 序列长度:524,288(512k);
- 并行策略:8路张量并行+16路上下文并行(将长序列拆分到不同GPU处理,避免单卡内存溢出);
- 数据:在Phase 3基础上,将20%数据替换为“长文档QA”(从学术文档中拆分1024token片段,生成QA后拼接回文档);
- 训练量:18.9B tokens;
-
效果(表4):512k训练+合成数据的RULER-128k得分达81.04,比256k训练(无合成数据)高10.85个百分点。
| 训练序列长度 | 是否用合成数据 | RULER-128k得分 |
|---|---|---|
| 128k | 是 | 73.68 |
| 256k | 否 | 70.19 |
| 256k | 是 | 79.04 |
| 512k | 是 | 81.04 |
4.6 基础模型评估(表5、表6)
Nemotron-Nano-12B-v2-Base(12B)与9B剪枝版(N-Nano-V2 9B Base)在多基准上超越同规模模型(Qwen3-8B、Gemma3-12B):
通用与数学能力(表5)
| 任务 | N-Nano-V2 12B Base | N-Nano-V2 9B Base | Qwen3 8B Base | Gemma3 12B Base |
|---|---|---|---|---|
| MMLU(通用知识) | 78.24 | 74.53 | 76.44 | 73.61 |
| MMLU-Pro(难知识) | 63.98 | 59.43 | 56.27 | 45.12 |
| GSM8K CoT(数学) | 91.66 | 91.36 | 84.00 | 74.45 |
| MATH(数学) | 83.54 | 80.50 | 55.40 | 42.40 |
| MATH Level 5(难题) | 67.61 | 63.64 | 29.91 | 17.71 |
| AIME 2024 pass@32 | 56.67 | 30.00 | 20.00 | 16.67 |
多语言能力(表6)
| 任务 | 语言 | N-Nano-V2 12B Base | N-Nano-V2 9B Base | Qwen3 8B Base | Gemma3 12B Base |
|---|---|---|---|---|---|
| Global-MMLU-Lite(平均) | - | 75.13 | 69.94 | 72.81 | 71.88 |
| MGSM(多语言数学) | 西班牙语 | 93.20 | 91.60 | 86.40 | 74.00 |
| MGSM | 中文 | 44.40 | 75.20 | 28.80 | 26.80 |
| MGSM(平均) | - | 80.00 | 84.80 | 64.53 | 57.13 |
五、对齐:从基础模型到“可控推理模型”
对齐阶段的目标是将12B基础模型转化为“能遵循指令、支持工具调用、可控思考”的实用模型,流程为:Base → SFT1 → SFT2 → SFT3 → DPO → RLHF → GRPO → 模型融合,共训练约90B tokens。
5.1 对齐数据体系(表7)
对齐数据以“单轮prompt-响应”为主,含推理链,覆盖6大领域,确保模型适配不同任务场景。
| 数据领域 | 样本数量 | 核心来源与处理 |
|---|---|---|
| 数学 | 1.5M | 复用预训练数学数据,用DeepSeek-R1-0528生成思考链响应 |
| 代码 | 1.1M | GitHub代码+合成QA,响应含代码解释与调试步骤 |
| 科学 | 2.0M | STEM数据集+学术论文,响应含科学原理推导 |
| 工具调用 | 400K | xlam-function-calling-60k、When2Call等,生成单轮/多轮/多步工具调用对话 |
| 对话 | 1.5M | LMSYS、HelpSteer2/3、WildChat1M,用Qwen3-235B生成自然对话响应 |
| 安全 | 5.0M | Nemotron Content Safety V2、HarmfulTasks,用DeepSeek-R1生成安全响应+guard模型过滤 |
| 多语言 | 2K | 翻译上述数据至5种语言(西、法、德、意、日) |
5.1.1 工具调用数据:多场景模拟
工具调用是推理任务的核心能力,论文设计“三角色模拟”生成高质量数据:
- User-Agent:模拟用户,审查工具列表、提出复杂查询、判断任务成功与否;
- Assistant-Agent:模拟模型,根据查询调用工具、解析结果、与用户交互;
- API-Server-Agent:模拟API服务器,检查参数合法性,返回正确结果或错误信息;
- 质量控制:用规则验证层确保工具调用格式正确(如参数类型匹配),仅保留成功轨迹。
5.1.2 预算控制数据:截断推理链
为实现“指定思考token数”功能,论文在SFT3阶段加入5%“截断推理链数据”:
- 生成逻辑:保留完整问题与最终答案,但将中间思考链截断至1~2k tokens;
- 作用:让模型学习“在有限思考步骤内收敛到正确答案”,避免推理超时。
5.2 多阶段对齐训练
5.2.1 SFT阶段:分三步优化核心能力
-
SFT1:全领域覆盖
- 数据:完整对齐数据集(含10%“无推理链数据”,让模型支持“直接回答”模式);
- 优化:将样本拼接为128k tokens序列,减少padding,强化长上下文学习;
- 目标:让模型适应指令格式,覆盖数学、代码、科学等基础领域。
-
SFT2:工具调用修复
- 问题:SFT1的128k拼接导致工具调用格式错乱(如参数缺失);
- 方案:不拼接样本,用完整工具调用数据集+其他领域子集训练;
- 目标:修复工具调用精度,BFCL v3(工具调用基准)得分从55%提升至66%。
-
SFT3:长上下文与预算控制
- 数据:加入长上下文文档QA(128k序列)+截断推理链数据;
- 目标:强化128k上下文理解,让模型适应“思考预算”约束。
5.2.2 IFeval RL:指令遵循精度提升
- 数据:从LMSYS Chat数据集采样16k prompt,添加IFeval风格指令(如“用3步解释答案”);
- 奖励信号:用规则验证器评分(如“是否分3步”“是否包含关键术语”);
- 效果:IFeval(严格指令遵循)得分从85%提升至89.81,接近Qwen3-14B(91.32)。
5.2.3 DPO:工具调用与多步推理增强
-
基准:BFCL v3(强调多步工具调用,如“先查天气再订机票”);
-
环境:WorkBench(多步可验证工具调用环境,通过数据库状态对比判断正确性);
-
流程:
- 对SFT3 checkpoint,生成每个WorkBench prompt的“成功轨迹”(正例)与“失败轨迹”(负例);
- 用DPO优化模型,让模型偏好正例;
-
效果:BFCL v3得分从66%提升至66.98,与Qwen3-8B(66.34)相当。
5.2.4 RLHF与GRPO:对话与推理平衡
-
基准:Arena-Hard(评估对话自然度与帮助性);
-
数据:HelpSteer3英文上下文,生成“带思考链”与“无思考链”两种响应;
-
奖励模型:Qwen-based模型,评分维度包括“正确性”“自然度”“帮助性”;
-
算法:用GRPO替代传统PPO,减少训练波动;
-
问题:RLHF提升对话能力,但数学推理得分下降3%;
-
解决方案:模型融合(checkpoint插值),公式为:
实验发现α=0.5时,推理与对话能力平衡最佳(ArenaHard得分74,接近Qwen3-8B的78.4)。
5.3 对齐后评估(表8)
Nemotron-Nano-v2-12B(对齐后的12B模型)在推理基准上超越Qwen3-8B,部分接近Qwen3-14B:
| 评估任务 | Nemotron-Nano-v2-12B | Qwen3-8B | Qwen3-14B |
|---|---|---|---|
| AIME-2024(数学竞赛) | 85.42 | 75.83 | 81.53 |
| AIME-2025(数学竞赛) | 76.25 | 69.31 | 66.60 |
| MATH-500(数学) | 97.75 | 96.30 | 96.85 |
| LiveCodeBench(代码) | 70.79 | 59.50 | 63.08 |
| RULER @ 128k(长上下文) | 83.36 | 74.13 | 73.55 |
| BFCL v3(工具调用) | 66.98 | 66.34 | 68.01 |
5.4 预算控制评估
论文设计“思考token数限制”功能,用户可指定模型生成多少token的思考链后必须输出答案,核心效果如下:
- 无补偿token:未训练截断数据的模型会“用更多输出token补偿思考不足”(如思考被限制1k token,输出用3k token写答案),而训练后模型无此问题;
- 格式正确:未训练模型在预算耗尽时,会重复生成闭合标签(如多次输出
</think>),训练后“格式正确率”保持95%以上(仅一个闭合标签); - 精度稳定:即使思考token从8k减少到2k,MATH-500得分仅下降2%(从97.75到95.5),验证预算控制的实用性。
六、剪枝与蒸馏:从12B到9B,适配单A10G GPU
基础模型(12B)的bfloat16权重需22.9GiB内存,超过A10G(22GiB)的容量,且128k上下文的KV缓存会进一步占用内存。论文基于Minitron策略(Muralidharan et al., 2024)进行压缩,核心是“剪枝+蒸馏”,在保证精度的前提下减少参数与内存占用。
6.1 压缩目标与约束
- 内存约束:单A10G GPU(22GiB),bfloat16精度,支持128k序列+batch=1,预留5%框架缓冲(1.1GiB)+1.3GiB视觉编码器内存,实际预算19.66GiB;
- 吞吐量约束:8k输入+16k输出场景下,vLLM吞吐量不低于150 token/s/GPU;
- 精度约束:推理基准平均得分不低于原始12B模型的95%。
6.2 重要性估计:决定“剪哪些部分”
剪枝的核心是“移除对精度影响最小的组件”,论文对层、FFN神经元、Mamba头、嵌入通道分别计算重要性:
6.2.1 层重要性:MSE迭代评估
- 流程:对每个层,临时移除后计算“原始模型logits与剪枝模型logits的MSE”,MSE越低说明该层越不重要;
- 策略:迭代移除MSE最低的层,直到达到目标深度;
- 结果:原始62层,剪至56层时精度损失最小(表9),若继续剪至52层,平均推理得分从51.48降至44.92(下降12.7%)。
| 层数 | 平均推理精度 |
|---|---|
| 52 | 44.92 |
| 54 | 47.35 |
| 56 | 51.48 |
6.2.2 FFN与嵌入通道重要性:激活值聚合
FFN层公式为:(
为平方ReLU),神经元重要性通过激活值聚合计算:
-
校准数据:1024个样本;
-
得分公式:对
的第i行(对应第i个神经元),计算激活值的均值或L2范数:
(B=batch维度,S=序列维度);
-
嵌入通道重要性:类似FFN,通过LayerNorm输出的激活值聚合计算。
6.2.3 Mamba头重要性:组内排序
Mamba-2层含多个分组,每个分组含多个头,剪枝需保留组内结构:
-
通道得分:对
投影的激活值,计算每个通道的 L2 范数:
-
头得分:对每个头,计算其所有通道得分的 L2 范数:
-
组内排序:在每个 Mamba 组内对头部排序,移除得分最低的头;
-
结论:剪 Mamba 头的精度损失比剪 FFN 大,因此优先剪 FFN(最终保留所有 Mamba 头)。
6.3 轻量级NAS:选择最优架构
通过“枚举候选→筛选最优”的方式,找到满足内存与精度约束的架构:
6.3.1 候选枚举
搜索空间包括:
- 深度:52~56层;
- 嵌入通道:4480~5120;
- FFN维度:13440~20480;
- Mamba头:112~128;
- 共生成数百个候选,仅保留内存≤19.66GiB的候选。
6.3.2 最优架构选择(表10)
固定深度为56层(精度最高),对宽度剪枝后的候选进行19B token蒸馏,评估精度与吞吐量:
| 候选 | 层数 | 嵌入通道 | FFN维度 | Mamba头 | 参数(B) | 平均精度 | 吞吐量(token/s) |
|---|---|---|---|---|---|---|---|
| 候选1 | 56 | 4480 | 17920 | 112 | 8.92 | 59.07 | 161.02 |
| 候选2 | 56 | 4480 | 15680 | 128 | 8.89 | 63.02 | 156.42 |
| 候选3 | 56 | 4800 | 14400 | 120 | 8.97 | 62.94 | 155.86 |
- 选择候选2:精度最高(63.02),吞吐量满足要求(156.42),参数8.89B(四舍五入为9B,即Nemotron-Nano-9B-v2)。
6.4 蒸馏再训练:恢复剪枝精度损失
剪枝会导致精度下降,论文通过“知识蒸馏(KD)”从12B教师模型向9B学生模型传递知识,分阶段进行:
6.4.1 推理模型蒸馏(对齐后模型)
- 深度剪枝:剪至56层,用60B token KD(8k序列);
- 宽度剪枝:嵌入通道4480,FFN 15680,用50B token KD(8k)+25B token KD(49k)+1B token KD(262k);
- 对齐恢复:DPO→GRPO→0.4B token KD(262k)→RLHF→模型融合;
- 数据集:70% SFT2数据+30%预训练数据(表11,该比例精度最高)。
| 推理SFT数据占比 | 预训练数据占比 | 平均数学精度 |
|---|---|---|
| 50% | 50% | 57.5 |
| 70% | 30% | 58.5 |
| 90% | 10% | 57.2 |
6.4.2 基础模型蒸馏(未对齐模型)
- 深度剪枝:剪至56层,用120B token KD(8k);
- 宽度剪枝:同候选2,用360B token KD(8k);
- 长上下文增强:用2.5B token KD(524k);
- 数据集:100%预训练数据,确保基础能力。
6.5 压缩效果验证
- 内存占用:9B模型bfloat16权重+128k KV缓存(batch=1)共占用20.8GiB,低于A10G的22GiB;
- 吞吐量:8k输入+16k输出场景下,吞吐量达156.42 token/s,是Qwen3-8B(25 token/s)的6.26×;
- 精度:在MATH-500、RULER-128k等基准上,仅比12B模型低3%以内,优于Qwen3-8B(表5、表8)。
七、结论与开源资源
7.1 核心贡献总结
- 架构创新:混合Mamba-Transformer架构,实现推理任务“高精度+高吞吐量”的平衡;
- 数据体系:构建curated+synthetic双轨数据,解决数学、多语言、长上下文数据质量问题;
- 压缩策略:基于Minitron的剪枝+蒸馏,让9B模型在单A10G上支持128k推理;
- 可控推理:通过截断训练实现“思考预算控制”,支持工具调用与多语言;
- 开源生态:开放模型与数据集,降低社区研究门槛。
7.2 开源资源(Hugging Face)
-
模型checkpoint:
- NVIDIA-Nemotron-Nano-9B-v2:对齐+剪枝的推理模型;
- NVIDIA-Nemotron-Nano-9B-v2-Base:剪枝后的基础模型;
- NVIDIA-Nemotron-Nano-12B-v2-Base:12B预训练基础模型;
-
数据集:
- Nemotron-PreTraining-Dataset-v1:6万亿token预训练数据(含Nemotron-CC-v2、Nemotron-CC-Math-v1等);
- Nemotron-Post-Training-Dataset-v2:对齐数据集(支持5种语言,链接待更新)。
八、关键术语解释(便于理解)
- Mamba-2:基于状态空间模型(SSM)的序列建模层,计算复杂度O(n),长序列效率高于Transformer;
- GQA(分组查询注意力) :将查询头分组,共享键值头,平衡内存(KV缓存减少)与精度;
- FP8/BF16:浮点数格式,FP8(8位)比BF16(16位)内存占用少50%,适合大模型训练;
- KV缓存:Transformer推理时缓存key和value,避免重复计算,但长序列时占用大量内存;
- 知识蒸馏(KD) :用大模型(教师)指导小模型(学生)训练,保留精度的同时减少参数;
- DPO(直接偏好优化) :通过“偏好数据”(正例vs负例)优化模型,无需显式奖励模型;
- GRPO(组相对策略优化) :RLHF的改进算法,减少训练波动,适合推理模型;
- RULER-128k:长上下文基准,评估模型在128k序列上的语义理解能力。