克服pytorch求导函数的缺陷!如何基于pytorch计算模型参数的Hessian矩阵/二阶导数

4,831 阅读4分钟

起因与问题

上周科研时,需要计算一个CNN参数的Hessian矩阵,即二阶导数。计算Hessian矩阵的实现思路想起来很简单:

  1. 对于模型参数,求两次导。第一次求导时使用loss对于参数求导
  2. 使用计算出来的梯度对于参数再次求导,就可以求出Hessian矩阵/二阶导数。

也就是只需要调用两次autogard.gard()就能够完成计算。此外,pytorch也提供了直接可以用于计算Hessian矩阵的api:autograd.functional.hessian()。但在调pytorch的api去实现的时候,却发现了如下问题:

pytorch的autogard.gard只能求标量对于向量的导数

假设我们的模型参数有N维,则loss函数对于参数的一阶导数为:

lossw=[lossw1lossw2losswn]\frac{\partial \operatorname{loss}}{\partial w} = [\frac{\partial \operatorname{loss}}{\partial w_1}\quad \frac{\partial \operatorname{loss}}{\partial w_2} \ldots \frac{\partial \operatorname{loss}}{\partial w_n}]

而二阶导数为:

H=[2lossw122 loss w1w22 loss w1wn2lossw2w12lossw222 loss w2wn2losswnw12losswnw22losswn2]H=\left[\begin{array}{cccc} \frac{\partial^2 l o s s}{\partial w_1^2} & \frac{\partial^2 \text { loss }}{\partial w_1 \partial w_2} & \ldots & \frac{\partial^2 \text { loss }}{\partial w_1 \partial w_n} \\ \frac{\partial^2 l o s s}{\partial w_2 \partial w_1} & \frac{\partial^2 l o s s}{\partial w_2^2} & \ldots & \frac{\partial^2 \text { loss }}{\partial w_2 \partial w_n} \\ \vdots & \vdots & \vdots & \vdots \\ \frac{\partial^2 l o s s}{\partial w_n \partial w_1} & \frac{\partial^2 \operatorname{loss}}{\partial w_n \partial w_2} & \ldots & \frac{\partial^2 \operatorname{loss}}{\partial w_n^2} \end{array}\right]

通过autogard.gard()可以求出loss的一阶导数,但是二阶导数的计算需要对于一阶导数求导,autogard.gard()无法直接对于向量求导。更确切的说,torch.autograd.grad()函数常用的输入参数有以下三个:

  • outputs (Tensor) – 一个可微函数的输出
  • inputs (Tensor) – 需要被求导的参数
  • grad_outputs (Tensor) – 通常是size与outputs相同的tensor

当output为一个长度大于一的向量时,就需要输入一个grad_outputs向量,其size与outputs相同。该参数相当于和向量做一个点乘,换句话说其将outputs向量转变为一个加权和的形式,从而将向量转化为标量,再进行求导。
因此,如果调用这个函数对梯度进行求导,最终得到一个长度为N的向量,而不会得到NxN大小的矩阵,该向量中每一个元素代表Hessian矩阵的某一行的和。即如下所示:

torch.autograd.functional.hessian无法自动求参数的Hession矩阵

pytorch中提供了一个专门用于计算Hessian矩阵的API:torch.autograd.functional.hessian(),但是在我尝试使用这个API去计算参数的Hession矩阵时却发现了一个问题:这个API无法自动对于参数进行求导。

torch.autograd.functional.hessian的主要参数如下所示:

  • func (function) – 一个可微函数,其输入一个tensor,输出一个标量
  • inputs (tuple of Tensors or Tensor) – 一个tensor或者tensor的tuple,其作为func的输入.

这里可以与gard()的参数进行比较,gard函数输入的是一个标量和一个tensor,其会基于标量的计算过程构建出计算图,计算出与输入的tensor相关的导数,也就是说,这个api是基于计算结果计算某一个参数的导数。但是torch.autograd.functional.hessian其输入的是一个函数和一个函数的输入。而后他会计算出这个函数关于这个输入的Hession矩阵。假设我们的func代表某一个模型,那么Hessian函数计算的则是这个模型输入针对于输入数据的Hession矩阵,而不是关于模型参数的Hessian矩阵。

现有的几种计算Hessian矩阵的实现方法

使用autogard.gard对于一阶梯度循环计算

以下方代码为例:

#定义函数  
x = torch.tensor([0., 0, 0], requires_grad=True)
b = torch.tensor([1., 3, 5])
A = torch.tensor([[-5, -3, -0.5], [-3, -2, 0], [-0.5, 0, -0.5]])
y = b@x + 0.5*x@A@x
 
#计算一阶导数,因为我们需要继续计算二阶导数,所以创建并保留计算图  
grad = torch.autograd.grad(y, x, retain_graph=True, create_graph=True)
#定义Print数组,为输出和进一步利用Hessian矩阵作准备  
Print = torch.tensor([])
for anygrad in grad[0]:  #torch.autograd.grad返回的是元组
    Print = torch.cat((Print, torch.autograd.grad(anygrad, x, retain_graph=True)[0]))
print(Print.view(x.size()[0], -1))

实验方案非常简单,如上所示,假设要对于一个线性模型的参数求二阶导数,首先调用gard求出y对于x的一阶导数,得到一个tensor,而后遍历tensor中的每一个元素,计算该标量对于参数x的导数,并存储结果。就可以得到一个NxN的矩阵,该矩阵即为Hessian矩阵。

调用hessian函数并对model进行封装

如下代码所示:

import torch
import numpy as np
from torch.nn import Module
import torch.nn.functional as F

class Net(Module):
    def __init__(self, h, w):
        super(Net, self).__init__()
        self.c1 = torch.nn.Conv2d(1, 32, 3, 1, 1)
        self.f2 = torch.nn.Linear(32 * h * w, 5)

    def forward(self, x):
        x = self.c1(x)
        x = x.view(x.size(0), -1)
        x = self.f2(x)
        return x

def forward_loss(a, b, c, d):
    p = [a.view(32, 1, 3, 3), b, c.view(5, 32 * 12 * 12), d]
    x = torch.randn(size=[8, 1, 12, 12], dtype=torch.float32)
    y = torch.randint(0, 5, [8])
    x = F.conv2d(x, p[0], p[1], 1, 1)
    x = x.view(x.size(0), -1)
    x = F.linear(x, p[2], p[3])
    loss = F.cross_entropy(x, y)
    return loss

if __name__ == '__main__':
    net = Net(12, 12)
    h = torch.autograd.functional.hessian(forward_loss, tuple([_.view(-1) for _ in net.parameters()]))

与调用gard函数时不同,由于hessian函数的输入必须是一个可调用函数,并且函数的输入与被求导的量一致。由于我们计算的通常是loss对于参数的导数lossw\frac{\partial \operatorname{loss}}{\partial w},而不是输出值对参数的导数yw\frac{\partial \operatorname{y}}{\partial w}因此,在使用此函数计算Hessian矩阵的时候,必须要写好一个前向传播计算loss的函数,并且该函数的输入为模型的参数

此外,loss计算函数在调用时,其输入的参数需要reshape成一维向量并包装成tuple,而后在函数内部再reshape为数组。

参考文献

  1. www.cnblogs.com/chester-cs/…
  2. stackoverflow.com/questions/6…