PyTorch 架构级学习系列 - 第 3 篇
本文将揭开 PyTorch 自动微分的神秘面纱。你将理解计算图是如何动态构建的,
backward()背后发生了什么,以及如何从零实现一个支持自动微分的系统。
📚 目录
- 从链式法则到自动微分:为什么需要 Autograd?
- 计算图的动态构建:每个操作都在记录历史
- backward() 的实现原理:反向传播的机制
- grad_fn 的追踪:梯度函数链
- 自定义操作:实现 autograd.Function
- 手写微型 Autograd:从零构建自动微分系统
- 实战:训练神经网络的完整流程
🎯 Part 1: 从链式法则到自动微分
💡 核心问题
训练神经网络需要计算梯度,手动求导太痛苦。能否让计算机自动计算梯度?
1.1 问题:手动求导的痛苦
假设我们有一个简单的函数:
import torch
# 函数:y = (x^2 + 3) * (2x + 1)
def f(x):
a = x ** 2
b = a + 3
c = 2 * x
d = c + 1
y = b * d
return y
x = torch.tensor(2.0)
y = f(x)
print(f"y = {y}") # y = 35.0
问题:如何计算 dy/dx?
方法 1:手动求导(符号微分)
# 手动推导:
# y = (x^2 + 3) * (2x + 1)
# dy/dx = d/dx[(x^2 + 3) * (2x + 1)]
# = (x^2 + 3) * 2 + 2x * (2x + 1) ← 乘法法则
# = 2x^2 + 6 + 4x^2 + 2x
# = 6x^2 + 2x + 6
def df_dx(x):
return 6 * x**2 + 2 * x + 6
x = 2.0
grad = df_dx(x)
print(f"dy/dx = {grad}") # dy/dx = 34.0
问题:
- ❌ 每个新函数都要手动推导公式
- ❌ 复杂函数的导数公式极其复杂
- ❌ 神经网络有数百万个参数,无法手动求导
方法 2:数值微分
# 利用导数定义:f'(x) = lim[h→0] (f(x+h) - f(x)) / h
def numerical_gradient(f, x, h=1e-5):
return (f(x + h) - f(x)) / h
x = 2.0
grad = numerical_gradient(f, x)
print(f"dy/dx ≈ {grad}") # dy/dx ≈ 34.00001
问题:
- ❌ 精度有限(取决于 h 的大小)
- ❌ 对于 n 个参数,需要计算 n+1 次前向传播
- ❌ 神经网络有数百万参数,计算量太大
方法 3:自动微分(PyTorch 的解决方案)
import torch
x = torch.tensor(2.0, requires_grad=True) # ← 告诉 PyTorch 追踪这个变量
# 前向传播(PyTorch 自动记录计算历史)
a = x ** 2
b = a + 3
c = 2 * x
d = c + 1
y = b * d
# 反向传播(PyTorch 自动计算梯度)
y.backward()
print(f"dy/dx = {x.grad}") # dy/dx = 34.0 ← 完全正确!
优点:
- ✅ 精确(不是近似)
- ✅ 高效(只需一次反向传播)
- ✅ 自动(不需要手动推导)
1.2 链式法则:自动微分的数学基础
链式法则(Chain Rule): 复合函数的导数
如果 y = f(g(x)),那么:
dy/dx = df/dg * dg/dx
多变量的链式法则:
如果 z 依赖于多个中间变量:
z = f(x, y)
x = g(u)
y = h(u)
那么:
dz/du = (∂z/∂x) * (dx/du) + (∂z/∂y) * (dy/du)
例子:理解链式法则
# 函数:z = (x + y)^2,其中 x = 2u, y = 3u
# 前向传播
u = 1.0
x = 2 * u # x = 2
y = 3 * u # y = 3
w = x + y # w = 5
z = w ** 2 # z = 25
# 反向传播(手动计算梯度)
# dz/du = ?
# 步骤 1:z 对 w 的偏导
dz_dw = 2 * w # dz/dw = 2w = 10
# 步骤 2:w 对 x 和 y 的偏导
dw_dx = 1 # ∂w/∂x = 1
dw_dy = 1 # ∂w/∂y = 1
# 步骤 3:x 和 y 对 u 的导数
dx_du = 2 # dx/du = 2
dy_du = 3 # dy/du = 3
# 步骤 4:应用链式法则
dz_dx = dz_dw * dw_dx # = 10 * 1 = 10
dz_dy = dz_dw * dw_dy # = 10 * 1 = 10
dz_du = dz_dx * dx_du + dz_dy * dy_du # = 10*2 + 10*3 = 50
print(f"dz/du = {dz_du}") # 50
验证(用 PyTorch):
u = torch.tensor(1.0, requires_grad=True)
x = 2 * u
y = 3 * u
w = x + y
z = w ** 2
z.backward()
print(f"dz/du = {u.grad}") # tensor(50.) ← 完全一致!
1.3 计算图:可视化计算过程
计算图(Computational Graph): 用图的方式表示计算过程
u = torch.tensor(1.0, requires_grad=True)
x = 2 * u
y = 3 * u
w = x + y
z = w ** 2
前向传播的计算图:
u (1.0)
/ \
/ \
*2 *3
/ \
x (2.0) y (3.0)
\ /
\ /
\ /
+
|
w (5.0)
|
^2
|
z (25.0)
反向传播的计算图:
u
/ \
/ \
∂z/∂x ∂z/∂y
=10 =10
\ /
\ /
\ /
dx/du=2 dy/du=3
| |
+--------+
|
dz/du=50
小结:
- 手动求导:不可行
- 数值微分:太慢
- 自动微分:利用链式法则,自动计算梯度
- 计算图:可视化表示计算和梯度传播
但这引出了新问题:PyTorch 如何知道计算的历史?如何构建这个图?
这就是 Part 2 要解答的...
🔨 Part 2: 计算图的动态构建
核心洞察
PyTorch 在前向传播时动态构建计算图,每个操作都会记录历史。
2.1 requires_grad:追踪的开关
import torch
# 不追踪梯度
x = torch.tensor(2.0)
y = x ** 2
print(f"x.requires_grad: {x.requires_grad}") # False
print(f"y.requires_grad: {y.requires_grad}") # False
print(f"y.grad_fn: {y.grad_fn}") # None ← 没有梯度函数
# 追踪梯度
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
print(f"x.requires_grad: {x.requires_grad}") # True
print(f"y.requires_grad: {y.requires_grad}") # True ← 自动继承
print(f"y.grad_fn: {y.grad_fn}") # <PowBackward0> ← 有梯度函数!
关键发现:
requires_grad=True告诉 PyTorch 追踪这个 Tensor- 所有依赖它的 Tensor 自动继承
requires_grad=True - 每个操作都会创建一个
grad_fn(梯度函数)
2.2 grad_fn:操作的记录
x = torch.tensor(2.0, requires_grad=True)
a = x ** 2 # a.grad_fn = <PowBackward0>
b = a + 3 # b.grad_fn = <AddBackward0>
c = 2 * x # c.grad_fn = <MulBackward0>
d = c + 1 # d.grad_fn = <AddBackward0>
y = b * d # y.grad_fn = <MulBackward0>
print(f"a.grad_fn: {a.grad_fn}") # <PowBackward0>
print(f"b.grad_fn: {b.grad_fn}") # <AddBackward0>
print(f"y.grad_fn: {y.grad_fn}") # <MulBackward0>
grad_fn 存储了什么?
# grad_fn 存储:
# 1. 操作类型(如 Mul、Add、Pow)
# 2. 输入 Tensor 的引用
# 3. 如何计算梯度的信息
print(f"y 的梯度函数: {y.grad_fn}")
print(f"y 的输入: {y.grad_fn.next_functions}")
# ((<AddBackward0 object>, 0), (<AddBackward0 object>, 0))
# 说明 y = b * d,b 和 d 都是通过 Add 操作得到的
2.3 计算图的构建过程
让我们详细看看每一步发生了什么:
import torch
x = torch.tensor(2.0, requires_grad=True)
print("=== 初始状态 ===")
print(f"x: value={x.item()}, grad_fn={x.grad_fn}")
# x 是叶子节点,没有 grad_fn
print("\n=== 步骤 1: a = x ** 2 ===")
a = x ** 2
print(f"a: value={a.item()}, grad_fn={a.grad_fn}")
# a: value=4.0, grad_fn=<PowBackward0>
# PowBackward0 记录了:
# - 操作:power
# - 输入:x
# - 指数:2
print("\n=== 步骤 2: b = a + 3 ===")
b = a + 3
print(f"b: value={b.item()}, grad_fn={b.grad_fn}")
# b: value=7.0, grad_fn=<AddBackward0>
# AddBackward0 记录了:
# - 操作:add
# - 输入:a 和常数 3
print("\n=== 步骤 3: y = b * 2 ===")
y = b * 2
print(f"y: value={y.item()}, grad_fn={y.grad_fn}")
# y: value=14.0, grad_fn=<MulBackward0>
# MulBackward0 记录了:
# - 操作:multiply
# - 输入:b 和常数 2
计算图的结构:
x (leaf)
|
| PowBackward0
| (input: x, exponent: 2)
↓
a = x^2
|
| AddBackward0
| (input: a, other: 3)
↓
b = a + 3
|
| MulBackward0
| (input: b, other: 2)
↓
y = b * 2
2.4 动态图 vs 静态图
PyTorch(动态图):
# 每次执行都构建新的图
for i in range(3):
x = torch.tensor(float(i), requires_grad=True)
if i < 2:
y = x ** 2
else:
y = x ** 3 # ← 不同的操作,不同的图
y.backward()
print(f"i={i}, grad={x.grad}")
TensorFlow 1.x(静态图):
# 伪代码:先定义图,再执行
import tensorflow as tf
# 定义阶段:构建图(只执行一次)
x = tf.placeholder(tf.float32)
y = x ** 2
grad = tf.gradients(y, x)
# 执行阶段:运行图(多次执行)
with tf.Session() as sess:
for i in range(3):
result = sess.run([y, grad], feed_dict={x: float(i)})
对比:
| 特性 | PyTorch (动态图) | TensorFlow 1.x (静态图) |
|---|---|---|
| 图的构建 | 每次运行时构建 | 预先定义 |
| 调试 | 容易(Python 调试器) | 困难(图内部) |
| 灵活性 | 高(可以用 if/for) | 低(需要特殊操作) |
| 性能 | 稍慢(重复构建) | 快(优化过的图) |
2.5 detach():切断梯度传播
x = torch.tensor(2.0, requires_grad=True)
a = x ** 2
b = a.detach() # ← 从计算图中分离
c = b + 3
print(f"a.requires_grad: {a.requires_grad}") # True
print(f"b.requires_grad: {b.requires_grad}") # False ← 被切断了
print(f"c.requires_grad: {c.requires_grad}") # False
# b 和 c 不在计算图中
print(f"a.grad_fn: {a.grad_fn}") # <PowBackward0>
print(f"b.grad_fn: {b.grad_fn}") # None ← 没有梯度函数
print(f"c.grad_fn: {c.grad_fn}") # None
使用场景:
# 场景 1:固定部分网络
pretrained_model = load_pretrained_model()
x = pretrained_model(input)
x = x.detach() # 不更新 pretrained_model 的参数
# 场景 2:避免内存泄漏
for epoch in range(100):
loss = compute_loss()
loss.backward()
# 记录损失(不需要梯度)
loss_value = loss.detach().item() # 或 loss.item()
losses.append(loss_value)
2.6 查看完整的计算图
import torch
x = torch.tensor(2.0, requires_grad=True)
a = x ** 2
b = a + 3
y = b * 2
def print_graph(tensor, indent=0):
"""递归打印计算图"""
print(" " * indent + f"{tensor.grad_fn}")
if tensor.grad_fn is not None:
for next_fn, _ in tensor.grad_fn.next_functions:
if next_fn is not None:
# 创建一个假的 tensor 来继续遍历
print_graph_from_fn(next_fn, indent + 1)
def print_graph_from_fn(grad_fn, indent=0):
print(" " * indent + f"{grad_fn}")
if hasattr(grad_fn, 'next_functions'):
for next_fn, _ in grad_fn.next_functions:
if next_fn is not None:
print_graph_from_fn(next_fn, indent + 1)
print("=== 计算图结构 ===")
print_graph(y)
# 输出:
# <MulBackward0>
# <AddBackward0>
# <PowBackward0>
# <AccumulateGrad> ← x 的梯度累积器
小结:
requires_grad=True开启追踪- 每个操作创建一个
grad_fn grad_fn记录操作类型和输入- 计算图是动态构建的(每次运行都重新构建)
detach()可以切断梯度传播
现在我们有了计算图,如何利用它计算梯度?
这就是 Part 3 要解答的...
⚙️ Part 3: backward() 的实现原理
核心机制
backward()从输出开始,沿着计算图反向传播梯度。
3.1 简单例子:理解 backward()
import torch
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2
print("=== 前向传播完成 ===")
print(f"x = {x}")
print(f"y = {y}")
print(f"x.grad = {x.grad}") # None(还没计算梯度)
print("\n=== 调用 backward ===")
y.backward()
print(f"x.grad = {x.grad}") # tensor(6.) ← dy/dx = 2x = 6
backward() 做了什么?
- 从
y开始(dy/dy = 1) - 调用
y.grad_fn.backward(1.0) PowBackward0计算 dy/dx = 2x = 6- 将梯度累积到
x.grad
3.2 链式求导:多个操作
x = torch.tensor(2.0, requires_grad=True)
a = x ** 2 # a = 4
b = a + 3 # b = 7
y = b * 2 # y = 14
y.backward()
print(f"dy/dx = {x.grad}") # tensor(8.)
手动验证:
# y = (x^2 + 3) * 2
# dy/dx = 2 * 2x = 4x
# 当 x=2 时,dy/dx = 8 ✓
反向传播的详细步骤:
# 步骤 1:初始化
# dy/dy = 1.0
# 步骤 2:y = b * 2
# dy/db = ∂(b * 2)/∂b = 2
# 传播:dy/db = dy/dy * 2 = 1.0 * 2 = 2.0
# 步骤 3:b = a + 3
# db/da = ∂(a + 3)/∂a = 1
# 传播:dy/da = dy/db * 1 = 2.0 * 1 = 2.0
# 步骤 4:a = x^2
# da/dx = ∂(x^2)/∂x = 2x = 4
# 传播:dy/dx = dy/da * 2x = 2.0 * 4 = 8.0 ✓
3.3 多输入的反向传播
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
# z = x * y + x^2
z = x * y + x ** 2
z.backward()
print(f"dz/dx = {x.grad}") # tensor(7.) = y + 2x = 3 + 4
print(f"dz/dy = {y.grad}") # tensor(2.) = x = 2
计算图:
x y
/ \ |
/ \ |
^2 \ |
| \ |
| \ |
| \ |
+ * /
\ | /
\ | /
\ |/
\ Add
\ /
\/
z
反向传播:
初始:dz/dz = 1
步骤 1:z = add(x^2, x*y)
dz/d(x^2) = 1
dz/d(x*y) = 1
步骤 2:x*y
d(x*y)/dx = y = 3
d(x*y)/dy = x = 2
传播:dz/dx += 1 * 3 = 3
dz/dy = 1 * 2 = 2
步骤 3:x^2
d(x^2)/dx = 2x = 4
传播:dz/dx += 1 * 4 = 4
最终:dz/dx = 3 + 4 = 7 ✓
dz/dy = 2 ✓
3.4 梯度累积
x = torch.tensor(2.0, requires_grad=True)
# 第一次计算
y1 = x ** 2
y1.backward()
print(f"第一次: x.grad = {x.grad}") # tensor(4.)
# 第二次计算(不清零)
y2 = x ** 3
y2.backward()
print(f"第二次: x.grad = {x.grad}") # tensor(16.) = 4 + 12 ← 累积了!
# 正确做法:每次 backward 前清零
x.grad.zero_()
y2.backward()
print(f"清零后: x.grad = {x.grad}") # tensor(12.)
为什么会累积?
# 因为 backward() 的实现是:
x.grad = x.grad + new_grad # 累加,不是赋值
# 这样设计的原因:
# 1. 支持多个损失函数
loss1.backward() # 计算第一个损失的梯度
loss2.backward() # 累加第二个损失的梯度
# total_grad = grad1 + grad2
# 2. 支持梯度检查点(gradient checkpointing)
3.5 非标量的 backward
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2 # y = [1, 4, 9]
# 错误:
# y.backward() # RuntimeError: grad can be implicitly created only for scalar outputs
# 正确:指定 gradient 参数
y.backward(torch.tensor([1.0, 1.0, 1.0]))
print(f"x.grad = {x.grad}") # tensor([2., 4., 6.]) = 2*x
为什么需要 gradient 参数?
# backward(gradient) 的含义:
# 假设有一个标量 L 依赖于 y
# gradient = dL/dy
# backward() 计算 dL/dx = (dL/dy) * (dy/dx)
# 例子:
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2
# 假设 L = sum(y) = y[0] + y[1] + y[2]
# dL/dy = [1, 1, 1]
y.backward(torch.tensor([1.0, 1.0, 1.0]))
print(x.grad) # [2., 4., 6.]
# 假设 L = 2*y[0] + 3*y[1] + 4*y[2]
# dL/dy = [2, 3, 4]
x.grad.zero_()
y.backward(torch.tensor([2.0, 3.0, 4.0]))
print(x.grad) # [4., 12., 24.] = [2*2*1, 3*2*2, 4*2*3]
3.6 retain_graph:保留计算图
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
# 第一次 backward
y.backward(retain_graph=True) # ← 保留图
print(f"第一次: x.grad = {x.grad}")
# 第二次 backward(需要先清零)
x.grad.zero_()
y.backward() # 可以再次调用
print(f"第二次: x.grad = {x.grad}")
# 如果不用 retain_graph:
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward()
# y.backward() # RuntimeError: Trying to backward through the graph a second time
为什么默认不保留图?
- 节省内存:计算图可能很大
- 大多数情况只需要一次 backward
小结:
backward()从输出开始反向传播梯度- 利用链式法则计算每个变量的梯度
- 梯度会累积(需要手动清零)
- 非标量需要指定 gradient 参数
- 默认会释放计算图(用
retain_graph=True保留)
现在我们理解了 backward 的流程,但每个操作的梯度是如何定义的?
这就是 Part 4 要解答的...
🔗 Part 4: grad_fn 的追踪:梯度函数链
核心机制
每个操作都有对应的梯度函数,它知道如何计算输入的梯度。
4.1 grad_fn 的本质
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
print(f"y.grad_fn: {y.grad_fn}")
# <PowBackward0 object at 0x...>
# grad_fn 存储了什么?
print(f"grad_fn 的输入: {y.grad_fn.next_functions}")
# ((<AccumulateGrad object at 0x...>, 0),)
# AccumulateGrad 是 x 的梯度累积器
grad_fn 的结构:
class PowBackward0:
def __init__(self, input_tensor, exponent):
self.saved_tensors = (input_tensor,) # 保存前向传播的输入
self.exponent = exponent
self.next_functions = ... # 指向输入的 grad_fn
def apply(self, grad_output):
"""计算输入的梯度"""
input_tensor = self.saved_tensors[0]
# d(x^n)/dx = n * x^(n-1)
grad_input = grad_output * self.exponent * (input_tensor ** (self.exponent - 1))
return grad_input
4.2 常见操作的梯度函数
加法(AddBackward)
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
z = x + y
z.backward()
print(f"dz/dx = {x.grad}") # tensor(1.)
print(f"dz/dy = {y.grad}") # tensor(1.)
# 梯度函数:
# d(x + y)/dx = 1
# d(x + y)/dy = 1
乘法(MulBackward)
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
z = x * y
z.backward()
print(f"dz/dx = {x.grad}") # tensor(3.) = y
print(f"dz/dy = {y.grad}") # tensor(2.) = x
# 梯度函数:
# d(x * y)/dx = y
# d(x * y)/dy = x
矩阵乘法(MmBackward)
A = torch.randn(3, 4, requires_grad=True)
B = torch.randn(4, 5, requires_grad=True)
C = torch.mm(A, B) # C = A @ B
# 假设 dL/dC 已知
grad_C = torch.ones(3, 5)
C.backward(grad_C)
print(f"A.grad.shape: {A.grad.shape}") # (3, 4)
print(f"B.grad.shape: {B.grad.shape}") # (4, 5)
# 梯度函数:
# dL/dA = (dL/dC) @ B^T
# dL/dB = A^T @ (dL/dC)
验证矩阵乘法的梯度:
# C = A @ B
# C[i,j] = sum(A[i,k] * B[k,j] for k)
# dL/dA[i,k] = sum(dL/dC[i,j] * dC/dA[i,k] for j)
# = sum(dL/dC[i,j] * B[k,j] for j)
# = (dL/dC @ B^T)[i,k]
# 同理:dL/dB[k,j] = (A^T @ dL/dC)[k,j]
4.3 ReLU 的梯度
x = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
y = torch.relu(x)
y.backward(torch.ones(3))
print(f"x.grad = {x.grad}") # tensor([0., 0., 1.])
# ReLU 的梯度:
# d(ReLU(x))/dx = 1 if x > 0 else 0
问题:0 点的梯度?
x = torch.tensor([0.0], requires_grad=True)
y = torch.relu(x)
y.backward()
print(f"x.grad = {x.grad}") # tensor([0.]) ← PyTorch 选择 0
# 数学上,ReLU 在 0 点不可导
# PyTorch 选择:d(ReLU(0))/dx = 0
4.4 saved_tensors:保存中间结果
class MySigmoid(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
output = 1 / (1 + torch.exp(-x))
ctx.save_for_backward(output) # ← 保存输出,用于反向传播
return output
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_tensors # ← 取出保存的值
# d(sigmoid(x))/dx = sigmoid(x) * (1 - sigmoid(x))
grad_input = grad_output * output * (1 - output)
return grad_input
为什么需要保存?
很多操作的梯度依赖于前向传播的中间结果:
# Sigmoid: d(σ(x))/dx = σ(x) * (1 - σ(x))
# 需要保存 σ(x)
# Softmax: dL/dx[i] 需要整个 softmax 的输出
# BatchNorm: 需要保存均值和方差
4.5 实际例子:Softmax 的梯度
def softmax_with_loss(logits, targets):
"""Softmax + Cross Entropy Loss"""
# 前向传播
exp_logits = torch.exp(logits)
probs = exp_logits / exp_logits.sum(dim=-1, keepdim=True)
# 交叉熵损失
loss = -torch.log(probs[range(len(targets)), targets]).mean()
return loss, probs
# 测试
logits = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
targets = torch.tensor([2]) # 正确类别是第 2 类
loss, probs = softmax_with_loss(logits, targets)
loss.backward()
print(f"probs: {probs}")
print(f"grad: {logits.grad}")
# 理论梯度:dL/dx[i] = (prob[i] - 1) if i == target else prob[i]
小结:
- 每个操作都有对应的
grad_fn grad_fn知道如何从输出梯度计算输入梯度- 某些操作需要保存前向传播的中间结果(
saved_tensors) - 常见操作的梯度:加法(1)、乘法(对方)、幂(n*x^(n-1))
如果我们想自定义一个新操作,如何定义它的梯度?
这就是 Part 5 要解答的...
🛠️ Part 5: 自定义操作:实现 autograd.Function
核心能力
当 PyTorch 内置操作不够用时,我们可以自己定义前向和反向传播。
5.1 为什么需要自定义操作?
场景 1:实现新的激活函数
# 例如:Mish 激活函数
# mish(x) = x * tanh(softplus(x))
# = x * tanh(ln(1 + e^x))
场景 2:优化性能
# 将多个操作融合成一个,减少内存访问
# 例如:fused_layer_norm = (x - mean) / std * gamma + beta
场景 3:实现特殊的梯度
# 例如:Straight-Through Estimator(二值化的近似梯度)
# 前向:y = sign(x)
# 反向:dy/dx = 1(假装是恒等函数)
5.2 基本结构
import torch
class MySquare(torch.autograd.Function):
"""自定义平方操作"""
@staticmethod
def forward(ctx, input):
"""
前向传播
Args:
ctx: 上下文对象,用于保存信息
input: 输入 Tensor
Returns:
output: 输出 Tensor
"""
ctx.save_for_backward(input) # 保存输入,用于反向传播
return input ** 2
@staticmethod
def backward(ctx, grad_output):
"""
反向传播
Args:
ctx: 上下文对象
grad_output: 输出的梯度 (dL/dy)
Returns:
grad_input: 输入的梯度 (dL/dx)
"""
input, = ctx.saved_tensors # 取出保存的输入
# d(x^2)/dx = 2x
grad_input = grad_output * 2 * input
return grad_input
# 使用
x = torch.tensor(3.0, requires_grad=True)
y = MySquare.apply(x) # ← 注意:用 .apply() 调用
y.backward()
print(f"x.grad = {x.grad}") # tensor(6.) = 2*3
5.3 多输入多输出
class MyAddMul(torch.autograd.Function):
"""自定义操作:z = (x + y) * w"""
@staticmethod
def forward(ctx, x, y, w):
ctx.save_for_backward(x, y, w)
return (x + y) * w
@staticmethod
def backward(ctx, grad_output):
x, y, w = ctx.saved_tensors
# dz/dx = w
grad_x = grad_output * w
# dz/dy = w
grad_y = grad_output * w
# dz/dw = x + y
grad_w = grad_output * (x + y)
return grad_x, grad_y, grad_w
# 测试
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
w = torch.tensor(4.0, requires_grad=True)
z = MyAddMul.apply(x, y, w)
z.backward()
print(f"dz/dx = {x.grad}") # tensor(4.) = w
print(f"dz/dy = {y.grad}") # tensor(4.) = w
print(f"dz/dw = {w.grad}") # tensor(5.) = x + y
5.4 实际例子:Sigmoid
class MySigmoid(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
output = 1 / (1 + torch.exp(-x))
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_tensors
# d(sigmoid(x))/dx = sigmoid(x) * (1 - sigmoid(x))
grad_input = grad_output * output * (1 - output)
return grad_input
# 验证
x = torch.tensor(0.0, requires_grad=True)
y = MySigmoid.apply(x)
y.backward()
print(f"sigmoid(0) = {y}") # tensor(0.5)
print(f"d(sigmoid)/dx = {x.grad}") # tensor(0.25) = 0.5 * 0.5
5.5 梯度检查(Gradient Check)
from torch.autograd import gradcheck
# 定义一个函数
class MyFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x ** 3
@staticmethod
def backward(ctx, grad_output):
x, = ctx.saved_tensors
return grad_output * 3 * x ** 2
# 梯度检查
x = torch.randn(10, dtype=torch.double, requires_grad=True)
test = gradcheck(MyFunc.apply, x, eps=1e-6)
print(f"梯度检查: {'通过' if test else '失败'}")
gradcheck 的原理:
# 对于每个输入元素 x[i]:
# 1. 用自动微分计算梯度
# 2. 用数值微分计算梯度
# 3. 比较两者是否接近
# 数值微分:
f(x[i] + eps) - f(x[i] - eps)
─────────────────────────────
2 * eps