如何用两行pytorch代码计算海森向量积(Hessian-vector-product)?

767 阅读2分钟

发布技术文章,文章内首/尾句带关键词“开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第 1 天,点击查看活动详情

用途与使用前提

在Pytorch中,有时需要计算模型参数的Hessian矩阵,但是假设模型参数长度为NN,则Hessian矩阵大小为N×NN×N,计算代价与存储代价过高。因此,假如计算Hessian的目的,是计算Hessian矩阵和某一向量的乘积,即:HzH·z,就不需要计算出整个Hessian矩阵。而可以使用一些优化方法,从而避免O(N2)O(N^2)的时间复杂度和空间复杂度快速完成计算。

原理

这种快速计算海森矩阵向量积(hvp)的方法,借助了PyTorch 的自动微分机制:如果我们采用 PyTorch 已经计算出的梯度,将其乘以 z,并对计算结果进行微分。那么就可以得到与直接计算 HzH·z 相同的结果。 所以我们可以在不知道 HH 的元素的情况下计算 HzH·z
如下图所示:

g=dLdθHv=dgTvdθ=dgTdθv+gTdvdθ=dgTdθv\begin{aligned} g & =\frac{d L}{d \theta} \\ H v & =\frac{d g^T v}{d \theta}=\frac{d g^T}{d \theta} v+g^T \frac{d v}{d \theta}=\frac{d g^T}{d \theta} v \end{aligned}

假设gg代表lossloss函数对参数θθ的梯度,则hvp的计算流程为:

  1. 计算出lossloss的梯度gg
  2. 将梯度gg与向量vv做向量乘积
  3. 将计算出的新向量对模型参数再次求导
Hv=dgTdθv=dgTvdθ\begin{aligned} H v & =\frac{d g^T}{d \theta} v=\frac{d g^T v}{d \theta} \end{aligned}

pytorch代码实现

def hvp(y,w,v):
    """
    Arguments:
        y: 标量/tensor,通常来说为loss函数的输出
        w: pytorch tensor的list,代表需要被求解hessian矩阵的参数
        v: pytorch tensor的list,代表需要与hessian矩阵乘积的向量
    Returns:
        return_grads: pytorch tensor的list, 代表hvp的最终结果.
    Raises:
        ValueError: y 与 w 长度不同."""
    if len (w) != len (v):
        raise (ValueError ("w and v must have the same length."))
    for i, v_ele in enumerate(v):
        v[i] = v_ele.cuda()
    # First backprop
    first_grads = grad (y, w, retain_graph=True, create_graph=True)
    # Second backprop
    return_grads = grad (first_grads, w, grad_outputs=v)
    return return_grads

引用

矩阵求导:
justindomke.wordpress.com/2009/01/17/…
文中公式:
github.com/amirgholami…