PyTorch学习阶段二 - Autograd 自动微分引擎

11 阅读14分钟

PyTorch 架构级学习系列 - 第 3 篇

本文将揭开 PyTorch 自动微分的神秘面纱。你将理解计算图是如何动态构建的,backward() 背后发生了什么,以及如何从零实现一个支持自动微分的系统。


📚 目录

  1. 从链式法则到自动微分:为什么需要 Autograd?
  2. 计算图的动态构建:每个操作都在记录历史
  3. backward() 的实现原理:反向传播的机制
  4. grad_fn 的追踪:梯度函数链
  5. 自定义操作:实现 autograd.Function
  6. 手写微型 Autograd:从零构建自动微分系统
  7. 实战:训练神经网络的完整流程

🎯 Part 1: 从链式法则到自动微分

💡 核心问题

训练神经网络需要计算梯度,手动求导太痛苦。能否让计算机自动计算梯度?

1.1 问题:手动求导的痛苦

假设我们有一个简单的函数:

import torch

# 函数:y = (x^2 + 3) * (2x + 1)
def f(x):
    a = x ** 2
    b = a + 3
    c = 2 * x
    d = c + 1
    y = b * d
    return y

x = torch.tensor(2.0)
y = f(x)
print(f"y = {y}")  # y = 35.0

问题:如何计算 dy/dx?

方法 1:手动求导(符号微分)

# 手动推导:
# y = (x^2 + 3) * (2x + 1)
# dy/dx = d/dx[(x^2 + 3) * (2x + 1)]
#       = (x^2 + 3) * 2 + 2x * (2x + 1)  ← 乘法法则
#       = 2x^2 + 6 + 4x^2 + 2x
#       = 6x^2 + 2x + 6

def df_dx(x):
    return 6 * x**2 + 2 * x + 6

x = 2.0
grad = df_dx(x)
print(f"dy/dx = {grad}")  # dy/dx = 34.0

问题:

  • ❌ 每个新函数都要手动推导公式
  • ❌ 复杂函数的导数公式极其复杂
  • ❌ 神经网络有数百万个参数,无法手动求导

方法 2:数值微分

# 利用导数定义:f'(x) = lim[h→0] (f(x+h) - f(x)) / h
def numerical_gradient(f, x, h=1e-5):
    return (f(x + h) - f(x)) / h

x = 2.0
grad = numerical_gradient(f, x)
print(f"dy/dx ≈ {grad}")  # dy/dx ≈ 34.00001

问题:

  • ❌ 精度有限(取决于 h 的大小)
  • ❌ 对于 n 个参数,需要计算 n+1 次前向传播
  • ❌ 神经网络有数百万参数,计算量太大

方法 3:自动微分(PyTorch 的解决方案)

import torch

x = torch.tensor(2.0, requires_grad=True)  # ← 告诉 PyTorch 追踪这个变量

# 前向传播(PyTorch 自动记录计算历史)
a = x ** 2
b = a + 3
c = 2 * x
d = c + 1
y = b * d

# 反向传播(PyTorch 自动计算梯度)
y.backward()

print(f"dy/dx = {x.grad}")  # dy/dx = 34.0 ← 完全正确!

优点:

  • ✅ 精确(不是近似)
  • ✅ 高效(只需一次反向传播)
  • ✅ 自动(不需要手动推导)

1.2 链式法则:自动微分的数学基础

链式法则(Chain Rule): 复合函数的导数

如果 y = f(g(x)),那么:
dy/dx = df/dg * dg/dx

多变量的链式法则:

如果 z 依赖于多个中间变量:
z = f(x, y)
x = g(u)
y = h(u)

那么:
dz/du = (∂z/∂x) * (dx/du) + (∂z/∂y) * (dy/du)

例子:理解链式法则

# 函数:z = (x + y)^2,其中 x = 2u, y = 3u

# 前向传播
u = 1.0
x = 2 * u      # x = 2
y = 3 * u      # y = 3
w = x + y      # w = 5
z = w ** 2     # z = 25

# 反向传播(手动计算梯度)
# dz/du = ?

# 步骤 1:z 对 w 的偏导
dz_dw = 2 * w  # dz/dw = 2w = 10

# 步骤 2:w 对 x 和 y 的偏导
dw_dx = 1      # ∂w/∂x = 1
dw_dy = 1      # ∂w/∂y = 1

# 步骤 3:x 和 y 对 u 的导数
dx_du = 2      # dx/du = 2
dy_du = 3      # dy/du = 3

# 步骤 4:应用链式法则
dz_dx = dz_dw * dw_dx  # = 10 * 1 = 10
dz_dy = dz_dw * dw_dy  # = 10 * 1 = 10
dz_du = dz_dx * dx_du + dz_dy * dy_du  # = 10*2 + 10*3 = 50

print(f"dz/du = {dz_du}")  # 50

验证(用 PyTorch):

u = torch.tensor(1.0, requires_grad=True)
x = 2 * u
y = 3 * u
w = x + y
z = w ** 2

z.backward()
print(f"dz/du = {u.grad}")  # tensor(50.) ← 完全一致!

1.3 计算图:可视化计算过程

计算图(Computational Graph): 用图的方式表示计算过程

u = torch.tensor(1.0, requires_grad=True)
x = 2 * u
y = 3 * u
w = x + y
z = w ** 2

前向传播的计算图:

        u (1.0)
       / \
      /   \
   *2      *3
    /       \
   x (2.0)  y (3.0)
    \       /
     \     /
      \   /
        +
        |
      w (5.0)
        |
       ^2
        |
      z (25.0)

反向传播的计算图:

        u
       / \
      /   \
   ∂z/∂x  ∂z/∂y
    =10    =10
     \      /
      \    /
       \  /
     dx/du=2  dy/du=3
        |        |
        +--------+
             |
          dz/du=50

小结:

  • 手动求导:不可行
  • 数值微分:太慢
  • 自动微分:利用链式法则,自动计算梯度
  • 计算图:可视化表示计算和梯度传播

但这引出了新问题:PyTorch 如何知道计算的历史?如何构建这个图?

这就是 Part 2 要解答的...


🔨 Part 2: 计算图的动态构建

核心洞察

PyTorch 在前向传播时动态构建计算图,每个操作都会记录历史。

2.1 requires_grad:追踪的开关

import torch

# 不追踪梯度
x = torch.tensor(2.0)
y = x ** 2
print(f"x.requires_grad: {x.requires_grad}")  # False
print(f"y.requires_grad: {y.requires_grad}")  # False
print(f"y.grad_fn: {y.grad_fn}")              # None ← 没有梯度函数

# 追踪梯度
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
print(f"x.requires_grad: {x.requires_grad}")  # True
print(f"y.requires_grad: {y.requires_grad}")  # True ← 自动继承
print(f"y.grad_fn: {y.grad_fn}")              # <PowBackward0> ← 有梯度函数!

关键发现:

  • requires_grad=True 告诉 PyTorch 追踪这个 Tensor
  • 所有依赖它的 Tensor 自动继承 requires_grad=True
  • 每个操作都会创建一个 grad_fn(梯度函数)

2.2 grad_fn:操作的记录

x = torch.tensor(2.0, requires_grad=True)
a = x ** 2       # a.grad_fn = <PowBackward0>
b = a + 3        # b.grad_fn = <AddBackward0>
c = 2 * x        # c.grad_fn = <MulBackward0>
d = c + 1        # d.grad_fn = <AddBackward0>
y = b * d        # y.grad_fn = <MulBackward0>

print(f"a.grad_fn: {a.grad_fn}")  # <PowBackward0>
print(f"b.grad_fn: {b.grad_fn}")  # <AddBackward0>
print(f"y.grad_fn: {y.grad_fn}")  # <MulBackward0>

grad_fn 存储了什么?

# grad_fn 存储:
# 1. 操作类型(如 Mul、Add、Pow)
# 2. 输入 Tensor 的引用
# 3. 如何计算梯度的信息

print(f"y 的梯度函数: {y.grad_fn}")
print(f"y 的输入: {y.grad_fn.next_functions}")
# ((<AddBackward0 object>, 0), (<AddBackward0 object>, 0))
# 说明 y = b * d,b 和 d 都是通过 Add 操作得到的

2.3 计算图的构建过程

让我们详细看看每一步发生了什么:

import torch

x = torch.tensor(2.0, requires_grad=True)
print("=== 初始状态 ===")
print(f"x: value={x.item()}, grad_fn={x.grad_fn}")
# x 是叶子节点,没有 grad_fn

print("\n=== 步骤 1: a = x ** 2 ===")
a = x ** 2
print(f"a: value={a.item()}, grad_fn={a.grad_fn}")
# a: value=4.0, grad_fn=<PowBackward0>
# PowBackward0 记录了:
#   - 操作:power
#   - 输入:x
#   - 指数:2

print("\n=== 步骤 2: b = a + 3 ===")
b = a + 3
print(f"b: value={b.item()}, grad_fn={b.grad_fn}")
# b: value=7.0, grad_fn=<AddBackward0>
# AddBackward0 记录了:
#   - 操作:add
#   - 输入:a 和常数 3

print("\n=== 步骤 3: y = b * 2 ===")
y = b * 2
print(f"y: value={y.item()}, grad_fn={y.grad_fn}")
# y: value=14.0, grad_fn=<MulBackward0>
# MulBackward0 记录了:
#   - 操作:multiply
#   - 输入:b 和常数 2

计算图的结构:

x (leaf)
  |
  | PowBackward0
  |   (input: x, exponent: 2)
  ↓
a = x^2
  |
  | AddBackward0
  |   (input: a, other: 3)
  ↓
b = a + 3
  |
  | MulBackward0
  |   (input: b, other: 2)
  ↓
y = b * 2

2.4 动态图 vs 静态图

PyTorch(动态图):

# 每次执行都构建新的图
for i in range(3):
    x = torch.tensor(float(i), requires_grad=True)
    if i < 2:
        y = x ** 2
    else:
        y = x ** 3  # ← 不同的操作,不同的图
    y.backward()
    print(f"i={i}, grad={x.grad}")

TensorFlow 1.x(静态图):

# 伪代码:先定义图,再执行
import tensorflow as tf

# 定义阶段:构建图(只执行一次)
x = tf.placeholder(tf.float32)
y = x ** 2
grad = tf.gradients(y, x)

# 执行阶段:运行图(多次执行)
with tf.Session() as sess:
    for i in range(3):
        result = sess.run([y, grad], feed_dict={x: float(i)})

对比:

特性PyTorch (动态图)TensorFlow 1.x (静态图)
图的构建每次运行时构建预先定义
调试容易(Python 调试器)困难(图内部)
灵活性高(可以用 if/for)低(需要特殊操作)
性能稍慢(重复构建)快(优化过的图)

2.5 detach():切断梯度传播

x = torch.tensor(2.0, requires_grad=True)
a = x ** 2
b = a.detach()  # ← 从计算图中分离
c = b + 3

print(f"a.requires_grad: {a.requires_grad}")  # True
print(f"b.requires_grad: {b.requires_grad}")  # False ← 被切断了
print(f"c.requires_grad: {c.requires_grad}")  # False

# b 和 c 不在计算图中
print(f"a.grad_fn: {a.grad_fn}")  # <PowBackward0>
print(f"b.grad_fn: {b.grad_fn}")  # None ← 没有梯度函数
print(f"c.grad_fn: {c.grad_fn}")  # None

使用场景:

# 场景 1:固定部分网络
pretrained_model = load_pretrained_model()
x = pretrained_model(input)
x = x.detach()  # 不更新 pretrained_model 的参数

# 场景 2:避免内存泄漏
for epoch in range(100):
    loss = compute_loss()
    loss.backward()

    # 记录损失(不需要梯度)
    loss_value = loss.detach().item()  # 或 loss.item()
    losses.append(loss_value)

2.6 查看完整的计算图

import torch

x = torch.tensor(2.0, requires_grad=True)
a = x ** 2
b = a + 3
y = b * 2

def print_graph(tensor, indent=0):
    """递归打印计算图"""
    print("  " * indent + f"{tensor.grad_fn}")
    if tensor.grad_fn is not None:
        for next_fn, _ in tensor.grad_fn.next_functions:
            if next_fn is not None:
                # 创建一个假的 tensor 来继续遍历
                print_graph_from_fn(next_fn, indent + 1)

def print_graph_from_fn(grad_fn, indent=0):
    print("  " * indent + f"{grad_fn}")
    if hasattr(grad_fn, 'next_functions'):
        for next_fn, _ in grad_fn.next_functions:
            if next_fn is not None:
                print_graph_from_fn(next_fn, indent + 1)

print("=== 计算图结构 ===")
print_graph(y)

# 输出:
# <MulBackward0>
#   <AddBackward0>
#     <PowBackward0>
#       <AccumulateGrad>  ← x 的梯度累积器

小结:

  • requires_grad=True 开启追踪
  • 每个操作创建一个 grad_fn
  • grad_fn 记录操作类型和输入
  • 计算图是动态构建的(每次运行都重新构建)
  • detach() 可以切断梯度传播

现在我们有了计算图,如何利用它计算梯度?

这就是 Part 3 要解答的...


⚙️ Part 3: backward() 的实现原理

核心机制

backward() 从输出开始,沿着计算图反向传播梯度。

3.1 简单例子:理解 backward()

import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 2

print("=== 前向传播完成 ===")
print(f"x = {x}")
print(f"y = {y}")
print(f"x.grad = {x.grad}")  # None(还没计算梯度)

print("\n=== 调用 backward ===")
y.backward()

print(f"x.grad = {x.grad}")  # tensor(6.) ← dy/dx = 2x = 6

backward() 做了什么?

  1. y 开始(dy/dy = 1)
  2. 调用 y.grad_fn.backward(1.0)
  3. PowBackward0 计算 dy/dx = 2x = 6
  4. 将梯度累积到 x.grad

3.2 链式求导:多个操作

x = torch.tensor(2.0, requires_grad=True)
a = x ** 2      # a = 4
b = a + 3       # b = 7
y = b * 2       # y = 14

y.backward()
print(f"dy/dx = {x.grad}")  # tensor(8.)

手动验证:

# y = (x^2 + 3) * 2
# dy/dx = 2 * 2x = 4x
# 当 x=2 时,dy/dx = 8 ✓

反向传播的详细步骤:

# 步骤 1:初始化
# dy/dy = 1.0

# 步骤 2:y = b * 2
# dy/db = ∂(b * 2)/∂b = 2
# 传播:dy/db = dy/dy * 2 = 1.0 * 2 = 2.0

# 步骤 3:b = a + 3
# db/da = ∂(a + 3)/∂a = 1
# 传播:dy/da = dy/db * 1 = 2.0 * 1 = 2.0

# 步骤 4:a = x^2
# da/dx = ∂(x^2)/∂x = 2x = 4
# 传播:dy/dx = dy/da * 2x = 2.0 * 4 = 8.0 ✓

3.3 多输入的反向传播

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

# z = x * y + x^2
z = x * y + x ** 2

z.backward()

print(f"dz/dx = {x.grad}")  # tensor(7.) = y + 2x = 3 + 4
print(f"dz/dy = {y.grad}")  # tensor(2.) = x = 2

计算图:

    x          y
   / \         |
  /   \        |
 ^2    \       |
  |     \      |
  |      \     |
  |       \    |
  +        *   /
   \       |  /
    \      | /
     \     |/
      \   Add
       \  /
        \/
         z

反向传播:

初始:dz/dz = 1

步骤 1:z = add(x^2, x*y)
  dz/d(x^2) = 1
  dz/d(x*y) = 1

步骤 2:x*y
  d(x*y)/dx = y = 3
  d(x*y)/dy = x = 2
  传播:dz/dx += 1 * 3 = 3
       dz/dy = 1 * 2 = 2

步骤 3:x^2
  d(x^2)/dx = 2x = 4
  传播:dz/dx += 1 * 4 = 4

最终:dz/dx = 3 + 4 = 7 ✓
     dz/dy = 2

3.4 梯度累积

x = torch.tensor(2.0, requires_grad=True)

# 第一次计算
y1 = x ** 2
y1.backward()
print(f"第一次: x.grad = {x.grad}")  # tensor(4.)

# 第二次计算(不清零)
y2 = x ** 3
y2.backward()
print(f"第二次: x.grad = {x.grad}")  # tensor(16.) = 4 + 12 ← 累积了!

# 正确做法:每次 backward 前清零
x.grad.zero_()
y2.backward()
print(f"清零后: x.grad = {x.grad}")  # tensor(12.)

为什么会累积?

# 因为 backward() 的实现是:
x.grad = x.grad + new_grad  # 累加,不是赋值

# 这样设计的原因:
# 1. 支持多个损失函数
loss1.backward()  # 计算第一个损失的梯度
loss2.backward()  # 累加第二个损失的梯度
# total_grad = grad1 + grad2

# 2. 支持梯度检查点(gradient checkpointing)

3.5 非标量的 backward

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2  # y = [1, 4, 9]

# 错误:
# y.backward()  # RuntimeError: grad can be implicitly created only for scalar outputs

# 正确:指定 gradient 参数
y.backward(torch.tensor([1.0, 1.0, 1.0]))

print(f"x.grad = {x.grad}")  # tensor([2., 4., 6.]) = 2*x

为什么需要 gradient 参数?

# backward(gradient) 的含义:
# 假设有一个标量 L 依赖于 y
# gradient = dL/dy
# backward() 计算 dL/dx = (dL/dy) * (dy/dx)

# 例子:
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2

# 假设 L = sum(y) = y[0] + y[1] + y[2]
# dL/dy = [1, 1, 1]
y.backward(torch.tensor([1.0, 1.0, 1.0]))
print(x.grad)  # [2., 4., 6.]

# 假设 L = 2*y[0] + 3*y[1] + 4*y[2]
# dL/dy = [2, 3, 4]
x.grad.zero_()
y.backward(torch.tensor([2.0, 3.0, 4.0]))
print(x.grad)  # [4., 12., 24.] = [2*2*1, 3*2*2, 4*2*3]

3.6 retain_graph:保留计算图

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2

# 第一次 backward
y.backward(retain_graph=True)  # ← 保留图
print(f"第一次: x.grad = {x.grad}")

# 第二次 backward(需要先清零)
x.grad.zero_()
y.backward()  # 可以再次调用
print(f"第二次: x.grad = {x.grad}")

# 如果不用 retain_graph:
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward()
# y.backward()  # RuntimeError: Trying to backward through the graph a second time

为什么默认不保留图?

  • 节省内存:计算图可能很大
  • 大多数情况只需要一次 backward

小结:

  • backward() 从输出开始反向传播梯度
  • 利用链式法则计算每个变量的梯度
  • 梯度会累积(需要手动清零)
  • 非标量需要指定 gradient 参数
  • 默认会释放计算图(用 retain_graph=True 保留)

现在我们理解了 backward 的流程,但每个操作的梯度是如何定义的?

这就是 Part 4 要解答的...


🔗 Part 4: grad_fn 的追踪:梯度函数链

核心机制

每个操作都有对应的梯度函数,它知道如何计算输入的梯度。

4.1 grad_fn 的本质

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2

print(f"y.grad_fn: {y.grad_fn}")
# <PowBackward0 object at 0x...>

# grad_fn 存储了什么?
print(f"grad_fn 的输入: {y.grad_fn.next_functions}")
# ((<AccumulateGrad object at 0x...>, 0),)
# AccumulateGrad 是 x 的梯度累积器

grad_fn 的结构:

class PowBackward0:
    def __init__(self, input_tensor, exponent):
        self.saved_tensors = (input_tensor,)  # 保存前向传播的输入
        self.exponent = exponent
        self.next_functions = ...  # 指向输入的 grad_fn

    def apply(self, grad_output):
        """计算输入的梯度"""
        input_tensor = self.saved_tensors[0]
        # d(x^n)/dx = n * x^(n-1)
        grad_input = grad_output * self.exponent * (input_tensor ** (self.exponent - 1))
        return grad_input

4.2 常见操作的梯度函数

加法(AddBackward)

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
z = x + y

z.backward()
print(f"dz/dx = {x.grad}")  # tensor(1.)
print(f"dz/dy = {y.grad}")  # tensor(1.)

# 梯度函数:
# d(x + y)/dx = 1
# d(x + y)/dy = 1

乘法(MulBackward)

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
z = x * y

z.backward()
print(f"dz/dx = {x.grad}")  # tensor(3.) = y
print(f"dz/dy = {y.grad}")  # tensor(2.) = x

# 梯度函数:
# d(x * y)/dx = y
# d(x * y)/dy = x

矩阵乘法(MmBackward)

A = torch.randn(3, 4, requires_grad=True)
B = torch.randn(4, 5, requires_grad=True)
C = torch.mm(A, B)  # C = A @ B

# 假设 dL/dC 已知
grad_C = torch.ones(3, 5)
C.backward(grad_C)

print(f"A.grad.shape: {A.grad.shape}")  # (3, 4)
print(f"B.grad.shape: {B.grad.shape}")  # (4, 5)

# 梯度函数:
# dL/dA = (dL/dC) @ B^T
# dL/dB = A^T @ (dL/dC)

验证矩阵乘法的梯度:

# C = A @ B
# C[i,j] = sum(A[i,k] * B[k,j] for k)

# dL/dA[i,k] = sum(dL/dC[i,j] * dC/dA[i,k] for j)
#            = sum(dL/dC[i,j] * B[k,j] for j)
#            = (dL/dC @ B^T)[i,k]

# 同理:dL/dB[k,j] = (A^T @ dL/dC)[k,j]

4.3 ReLU 的梯度

x = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
y = torch.relu(x)

y.backward(torch.ones(3))
print(f"x.grad = {x.grad}")  # tensor([0., 0., 1.])

# ReLU 的梯度:
# d(ReLU(x))/dx = 1 if x > 0 else 0

问题:0 点的梯度?

x = torch.tensor([0.0], requires_grad=True)
y = torch.relu(x)
y.backward()
print(f"x.grad = {x.grad}")  # tensor([0.]) ← PyTorch 选择 0

# 数学上,ReLU 在 0 点不可导
# PyTorch 选择:d(ReLU(0))/dx = 0

4.4 saved_tensors:保存中间结果

class MySigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        output = 1 / (1 + torch.exp(-x))
        ctx.save_for_backward(output)  # ← 保存输出,用于反向传播
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors  # ← 取出保存的值
        # d(sigmoid(x))/dx = sigmoid(x) * (1 - sigmoid(x))
        grad_input = grad_output * output * (1 - output)
        return grad_input

为什么需要保存?

很多操作的梯度依赖于前向传播的中间结果:

# Sigmoid: d(σ(x))/dx = σ(x) * (1 - σ(x))
#          需要保存 σ(x)

# Softmax: dL/dx[i] 需要整个 softmax 的输出

# BatchNorm: 需要保存均值和方差

4.5 实际例子:Softmax 的梯度

def softmax_with_loss(logits, targets):
    """Softmax + Cross Entropy Loss"""
    # 前向传播
    exp_logits = torch.exp(logits)
    probs = exp_logits / exp_logits.sum(dim=-1, keepdim=True)

    # 交叉熵损失
    loss = -torch.log(probs[range(len(targets)), targets]).mean()

    return loss, probs

# 测试
logits = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
targets = torch.tensor([2])  # 正确类别是第 2 类

loss, probs = softmax_with_loss(logits, targets)
loss.backward()

print(f"probs: {probs}")
print(f"grad: {logits.grad}")

# 理论梯度:dL/dx[i] = (prob[i] - 1) if i == target else prob[i]

小结:

  • 每个操作都有对应的 grad_fn
  • grad_fn 知道如何从输出梯度计算输入梯度
  • 某些操作需要保存前向传播的中间结果(saved_tensors
  • 常见操作的梯度:加法(1)、乘法(对方)、幂(n*x^(n-1))

如果我们想自定义一个新操作,如何定义它的梯度?

这就是 Part 5 要解答的...


🛠️ Part 5: 自定义操作:实现 autograd.Function

核心能力

当 PyTorch 内置操作不够用时,我们可以自己定义前向和反向传播。

5.1 为什么需要自定义操作?

场景 1:实现新的激活函数

# 例如:Mish 激活函数
# mish(x) = x * tanh(softplus(x))
#         = x * tanh(ln(1 + e^x))

场景 2:优化性能

# 将多个操作融合成一个,减少内存访问
# 例如:fused_layer_norm = (x - mean) / std * gamma + beta

场景 3:实现特殊的梯度

# 例如:Straight-Through Estimator(二值化的近似梯度)
# 前向:y = sign(x)
# 反向:dy/dx = 1(假装是恒等函数)

5.2 基本结构

import torch

class MySquare(torch.autograd.Function):
    """自定义平方操作"""

    @staticmethod
    def forward(ctx, input):
        """
        前向传播
        Args:
            ctx: 上下文对象,用于保存信息
            input: 输入 Tensor
        Returns:
            output: 输出 Tensor
        """
        ctx.save_for_backward(input)  # 保存输入,用于反向传播
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        """
        反向传播
        Args:
            ctx: 上下文对象
            grad_output: 输出的梯度 (dL/dy)
        Returns:
            grad_input: 输入的梯度 (dL/dx)
        """
        input, = ctx.saved_tensors  # 取出保存的输入
        # d(x^2)/dx = 2x
        grad_input = grad_output * 2 * input
        return grad_input

# 使用
x = torch.tensor(3.0, requires_grad=True)
y = MySquare.apply(x)  # ← 注意:用 .apply() 调用

y.backward()
print(f"x.grad = {x.grad}")  # tensor(6.) = 2*3

5.3 多输入多输出

class MyAddMul(torch.autograd.Function):
    """自定义操作:z = (x + y) * w"""

    @staticmethod
    def forward(ctx, x, y, w):
        ctx.save_for_backward(x, y, w)
        return (x + y) * w

    @staticmethod
    def backward(ctx, grad_output):
        x, y, w = ctx.saved_tensors

        # dz/dx = w
        grad_x = grad_output * w

        # dz/dy = w
        grad_y = grad_output * w

        # dz/dw = x + y
        grad_w = grad_output * (x + y)

        return grad_x, grad_y, grad_w

# 测试
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
w = torch.tensor(4.0, requires_grad=True)

z = MyAddMul.apply(x, y, w)
z.backward()

print(f"dz/dx = {x.grad}")  # tensor(4.) = w
print(f"dz/dy = {y.grad}")  # tensor(4.) = w
print(f"dz/dw = {w.grad}")  # tensor(5.) = x + y

5.4 实际例子:Sigmoid

class MySigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        output = 1 / (1 + torch.exp(-x))
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        # d(sigmoid(x))/dx = sigmoid(x) * (1 - sigmoid(x))
        grad_input = grad_output * output * (1 - output)
        return grad_input

# 验证
x = torch.tensor(0.0, requires_grad=True)
y = MySigmoid.apply(x)

y.backward()
print(f"sigmoid(0) = {y}")      # tensor(0.5)
print(f"d(sigmoid)/dx = {x.grad}")  # tensor(0.25) = 0.5 * 0.5

5.5 梯度检查(Gradient Check)

from torch.autograd import gradcheck

# 定义一个函数
class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 3 * x ** 2

# 梯度检查
x = torch.randn(10, dtype=torch.double, requires_grad=True)
test = gradcheck(MyFunc.apply, x, eps=1e-6)
print(f"梯度检查: {'通过' if test else '失败'}")

gradcheck 的原理:

# 对于每个输入元素 x[i]:
# 1. 用自动微分计算梯度
# 2. 用数值微分计算梯度
# 3. 比较两者是否接近

# 数值微分:
f(x[i] + eps) - f(x[i] - eps)
─────────────────────────────
          2 * eps