本文将讲解 MindSpore 中两个高频核心知识点:
Stop Gradient 梯度截断:屏蔽指定张量的梯度回传,消除无关张量对梯度计算的影响;
has_aux 辅助数据参数:自动处理多输出函数的梯度计算,无需手动截断梯度;
这两个知识点是解决复杂场景梯度计算的核心。
问题引入:多输出函数的梯度计算陷阱
默认情况下,如果前向函数只返回 loss 一个值,mindspore.grad 只会计算「loss 对指定参数的梯度」,这也是我们训练模型的核心诉求。
但如果前向函数返回多个输出项(如 loss + logits 预测值),MindSpore 的微分函数会默认计算:所有输出项对指定参数的梯度之和,这会导致最终的梯度值失真,与我们需要的「仅 loss 求梯度」的结果不一致!
实战验证:多输出函数的梯度失真问题
# 定义返回 loss + z(预测值) 的多输出函数
def function_with_logits(x, y, w, b):
z = ops.matmul(x, w) + b
loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
return loss, z # 输出项1:loss,输出项2:预测值z
# 生成微分函数,依旧对w(2)、b(3)求导
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print("多输出函数的梯度值:\n", grads)
运行结果:
多输出函数的梯度值:
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 1.32618928e+00, 1.01589143e+00, 1.04216456e+00],
[ 1.32618928e+00, 1.01589143e+00, 1.04216456e+00],
[ 1.32618928e+00, 1.01589143e+00, 1.04216456e+00],
[ 1.32618928e+00, 1.01589143e+00, 1.04216456e+00],
[ 1.32618928e+00, 1.01589143e+00, 1.04216456e+00]]), Tensor(shape=[3], dtype=Float32, value= [ 1.32618928e+00, 1.01589143e+00, 1.04216456e+00]))
结果对比:
单输出函数(仅 loss):w 的梯度值约为 0.326、0.0159、0.0422;
多输出函数(loss+z):w 的梯度值约为 1.326、1.0159、1.0422;
梯度值完全不同,这就是「多输出项梯度叠加」导致的失真,这不是我们想要的结果!
解决方案一:Stop Gradient 手动梯度截断【核心 API】
- Stop Gradient 核心作用
MindSpore 提供 mindspore.ops.stop_gradient 接口,是梯度计算中的「截断利器」,核心功能有 3 个:
对指定 Tensor 进行梯度截断,消除该 Tensor 对梯度计算的所有影响;
屏蔽无关输出项的梯度回传,让微分函数只计算「目标项(loss)」的梯度;
阻止梯度从当前 Tensor 流向计算图的上游节点,不改变 Tensor 的数值,仅改变梯度传播属性。
核心特性:stop_gradient(z) 只会修改 z 的梯度传播标记,不会改变 z 的数值本身,我们依然可以正常获取和使用 z 的值,只是它不再参与梯度计算。
- 实战:使用 Stop Gradient 修正梯度计算
只需要对不需要参与梯度计算的输出项(本例中的 z)包裹stop_gradient,即可实现「仅 loss 求梯度」:
def function_stop_gradient(x, y, w, b):
z = ops.matmul(x, w) + b
loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
return loss, ops.stop_gradient(z) # 对z进行梯度截断
# 生成微分函数并求梯度
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print("梯度截断后的梯度值:\n", grads)
运行结果:
梯度截断后的梯度值:
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02]]), Tensor(shape=[3], dtype=Float32, value= [ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02]))
结果验证:此时的梯度值与「单输出函数仅返回 loss」的梯度值完全一致,问题完美解决!
解决方案二:has_aux=True 自动处理辅助数据【推荐最佳实践】
- 辅助数据(Auxiliary data)定义
在 MindSpore 的自动微分体系中,辅助数据 特指:前向函数中「除第一个输出项外的其他所有输出项」。
行业通用约定:前向函数的第一个返回值必须是损失值 loss,其余返回值均为辅助数据(如预测值、中间特征、准确率等)。
我们训练模型的核心诉求永远是「求 loss 对参数的梯度」,辅助数据只是为了监控训练过程,不需要参与梯度计算。
- has_aux 参数的核心能力
mindspore.grad 和 mindspore.value_and_grad 都提供了 has_aux 布尔型参数,当设置 has_aux=True 时:
自动将函数的「第一个输出项」作为梯度计算的唯一目标(仅求 loss 的梯度);
自动对「所有辅助数据」执行梯度截断(等价于手动加stop_gradient);
微分函数的返回值会拆分为「梯度结果 + 辅助数据元组」,无需手动处理;
语法更简洁,无需修改原函数的返回逻辑,是处理多输出函数的最优解。
-
实战:has_aux=True 优雅实现梯度计算 + 辅助数据返回
复用未做任何修改的多输出函数 function_with_logits
def function_with_logits(x, y, w, b): z = ops.matmul(x, w) + b loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z)) return loss, z
仅需添加 has_aux=True,无需手动截断梯度
grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True) grads, (z,) = grad_fn(x, y, w, b) # 解构:梯度 + 辅助数据 print("梯度值(与单输出一致):\n", grads) print("辅助数据z(预测值):\n", z)
运行结果:
梯度值(与单输出一致):
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02]]), Tensor(shape=[3], dtype=Float32, value= [ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02]))
辅助数据z(预测值):
[ 3.8211915 -2.994512 -1.932323 ]
两大方案对比与选型建议
Stop Gradient:适合「精细化梯度控制」,比如只对函数中某一个中间张量截断梯度,而非所有辅助数据;灵活性高,适合复杂场景;
has_aux=True:适合「标准多输出场景」,只要满足「第一个返回值是 loss」的约定,无脑使用即可;简洁高效,推荐优先使用;
核心总结
多输出函数的默认梯度计算是「所有输出项梯度之和」,会导致梯度失真,必须做梯度截断处理;
stop_gradient 是梯度截断的基础 API,核心是「消除指定 Tensor 的梯度影响,不改变数值」;
has_aux=True 是辅助数据的最优解,自动截断辅助数据梯度,推荐在标准场景中使用;
梯度截断的核心目的:让模型的梯度计算始终围绕「损失函数」展开,保证参数更新的正确性。