PyTorch:定义自定义 autograd 函数(Defining New autograd Functions)

0 阅读4分钟

PyTorch:定义自定义 autograd 函数(Defining New autograd Functions)

本示例实现一个三阶多项式模型,通过最小化欧式距离平方和,拟合区间 ([-π, π]) 内的正弦函数 (y=\sin(x))。

与之前示例不同的是,本次将多项式改写为勒让德多项式形式: [ y = a + bP_3(c+dx) ] 其中 (P_3(x) = \frac{1}{2}(5x^3-3x)) 是三阶勒让德多项式

该实现的核心:

  • 基于 PyTorch 张量完成前向传播
  • 自定义 autograd 函数实现勒让德多项式的前向计算和梯度推导
  • 利用 PyTorch autograd 自动完成整体反向传播

数学推导补充:

  • 三阶勒让德多项式:(P_3(x) = \frac{1}{2}(5x^3-3x))
  • 其导数:(P_3'(x) = \frac{3}{2}(5x^2-1))

完整代码(带详细注释)

import torch
import math

class LegendrePolynomial3(torch.autograd.Function):
    """
    自定义 autograd 函数:实现三阶勒让德多项式的前向和反向传播
    需继承 torch.autograd.Function 并实现 forward 和 backward 静态方法
    """

    @staticmethod
    def forward(ctx, input):
        """
        前向传播:接收输入张量,返回输出张量
        ctx:上下文对象,用于保存前向传播的中间结果,供反向传播使用
        - ctx.save_for_backward():保存张量到上下文(仅能保存张量)
        - 也可直接给 ctx 设置属性保存其他类型数据(如 ctx.xxx = xxx)
        """
        ctx.save_for_backward(input)  # 保存输入张量,供反向传播使用
        return 0.5 * (5 * input ** 3 - 3 * input)  # 计算三阶勒让德多项式

    @staticmethod
    def backward(ctx, grad_output):
        """
        反向传播:接收输出的梯度,计算并返回输入的梯度
        ctx:上下文对象,可获取前向传播保存的张量
        grad_output:损失函数对当前函数输出的梯度
        """
        input, = ctx.saved_tensors  # 取出前向传播保存的输入张量
        # 计算输入的梯度:grad_output * P3'(input)
        return grad_output * 1.5 * (5 * input ** 2 - 1)

# 设置数据类型和计算设备
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # 取消注释即可在 GPU 上运行

# 创建输入和输出张量
# 默认 requires_grad=False:无需计算这些张量的梯度
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)  # 拟合正弦函数

# 创建权重张量(y = a + b*P3(c + d*x),共4个参数)
# 初始化值接近真实解,确保模型收敛
# requires_grad=True:需要自动计算这些参数的梯度
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

learning_rate = 5e-6  # 学习率
for t in range(2000):  # 迭代2000次
    # 应用自定义 autograd 函数:使用 apply 方法,别名 P3 简化调用
    P3 = LegendrePolynomial3.apply

    # 前向传播:计算预测值y,使用自定义的 P3 函数
    y_pred = a + b * P3(c + d * x)

    # 计算并打印损失(每100次迭代打印一次)
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # 自动微分:执行反向传播,计算所有可导参数的梯度
    loss.backward()

    # 梯度下降更新权重
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # 手动清零梯度(避免梯度累积)
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

# 打印最终拟合结果
print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')

典型运行输出

99 212.92306518554688
199 127.0043182373047
299 76.61222839355469
399 46.92702102661133
499 29.35991859436035
599 18.689119338989258
699 12.152021408081055
799 8.103891372680664
899 5.6281156539917
999 4.113130569458008
1099 3.1941769123077393
1199 2.622696876525879
1299 2.268768072128296
1399 2.048941135406494
1499 1.9119837284088135
1599 1.8257107734680176
1699 1.7701241970062256
1799 1.733727216720581
1899 1.709191083908081
1999 1.6917126178741455
Result: y = -0.00030826766867220402 + -0.9957594871520996 * P3(-0.0012611821293830872 + 0.29918044805526733 x)

核心知识点解析

关键操作作用说明
torch.autograd.Function自定义 autograd 函数的基类
forward(ctx, input)实现前向传播逻辑,ctx 用于保存中间结果
backward(ctx, grad_output)实现反向传播逻辑,计算输入的梯度
ctx.save_for_backward()保存张量到上下文,供反向传播使用
Function.apply调用自定义 autograd 函数的唯一方式
ctx.saved_tensors获取前向传播保存的张量

总结

  1. 自定义 autograd 函数需继承 torch.autograd.Function,并实现 forwardbackward 静态方法,分别定义前向计算和梯度推导逻辑;
  2. ctx 上下文对象是前向和反向传播的桥梁,可通过 save_for_backward 保存张量,供反向传播使用;
  3. 自定义 autograd 函数通过 apply 方法调用,可无缝集成到 PyTorch 的自动微分体系中,无需修改原有反向传播逻辑;
  4. 该方式适用于原生算子无法满足需求的场景(如特殊数学函数、自定义运算),需手动推导梯度公式以实现 backward 方法。