海森矩阵的三种计算方法及其Pytorch实现

382 阅读6分钟

问题的缘起

在深入研究神经网络二阶优化算法的过程中,我遭遇了一个计算上的重大挑战:海塞矩阵-向量乘积(Hessian-Vector Product, HVP)的高效求解。这一计算量在现代优化理论中占据着举足轻重的地位,无论是经典的牛顿法、共轭梯度算法,还是当下备受关注的双层优化(bilevel optimization)技术,都对其有着强烈的依赖。

最初的解决思路显得相当直观:构建完整的海塞矩阵,随后执行矩阵-向量乘法运算。然而,现实很快给出了残酷的回应:

存储复杂度的指数级爆炸

考虑一个包含d个可训练参数的神经网络架构,其对应的海塞矩阵规模为d×d,存储需求呈O(d²)增长。当面对现代大规模深度学习模型时,这种需求变得完全不可行。以GPT-3的1750亿参数为例进行估算:完整海塞矩阵的存储将消耗约30万TB的空间资源!

更为关键的是,在绝大多数应用场景中,我们的真实需求并非获得完整的海塞矩阵,而仅仅是其与特定向量的乘积结果。这种做法类似于为了完成简单的矩阵乘法而预先展开所有参与运算的矩阵——既造成空间浪费,又带来时间损耗。

传统方法的技术壁垒

PyTorch框架提供的autograd.functional.hessian()虽然具备海塞矩阵计算能力,但在实际应用中暴露出两个致命缺陷:

  1. 对模型参数的海塞矩阵无法直接求解,仅支持输入数据层面的计算
  2. 即使勉强可用,仍然无法摆脱O(d²)存储复杂度的束缚

那么,是否存在绕过显式海塞矩阵构造而直接获得HVP的计算路径?答案是肯定的,关键在于对自动微分技术的创新性运用。

理论突破:Pearlmutter的核心洞察

1994年,Pearlmutter提出了一个看似简单却具有深远影响的数学观察:

海塞矩阵-向量乘积本质上等价于梯度的方向导数

其数学表述为:

∇²f(θ)v = ∇θ(∇f(θ) · v)

这一等价关系揭示了HVP可以通过计算梯度在指定方向v上的变化率来获得,而这恰恰是自动微分框架的核心优势所在。

基于这一理论基础,我们可以设计出三种截然不同的计算策略,它们通过巧妙组合前向模式与反向模式自动微分来实现目标:

三大计算策略的技术剖析

策略一:前向-反向混合模式 (Forward-over-Reverse)

核心理念:首先运用反向模式构建梯度函数,继而通过前向模式计算该函数的方向导数。

执行序列

  1. 梯度函数构建:grad_f = ∇f(θ)
  2. 雅可比-向量乘积计算:∇²f(θ)v = JVP(grad_f, v)

JAX框架实现

def hvp_forward_over_reverse_jax(f, params, v):
    return jax.jvp(jax.grad(f), (params,), (v,))[1]

# 应用示例
def loss_fn(params):
    # 这里定义你的损失函数
    return 0.5 * jnp.sum(params**2)

params = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.1, 0.2, 0.3])
hvp_result = hvp_forward_over_reverse_jax(loss_fn, params, v)

PyTorch框架实现

def hvp_forward_over_reverse_torch(f, params, v):
    grad_fun = torch.func.grad(f)
    return torch.func.jvp(grad_fun, (params,), (v,))[1]

# 应用示例
def loss_fn(params):
    return 0.5 * torch.sum(params**2)

params = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
v = torch.tensor([0.1, 0.2, 0.3])
hvp_result = hvp_forward_over_reverse_torch(loss_fn, params, v)

技术优势

  • 能够在单次计算中同时获得梯度与HVP,在双层优化场景中具有显著价值
  • 内存使用效率相对优异
  • 概念框架直观,便于理解和调试

策略二:反向-前向混合模式 (Reverse-over-Forward)

核心理念:先通过前向模式计算方向导数,随后运用反向模式对该方向导数求梯度。

执行序列

  1. 方向导数计算:directional_deriv = ∇f(θ) · v
  2. 梯度求解:∇²f(θ)v = ∇θ(directional_deriv)

JAX框架实现

def hvp_reverse_over_forward_jax(f, params, v):
    jvp_fun = lambda params: jax.jvp(f, (params,), (v,))[1]
    return jax.grad(jvp_fun)(params)

PyTorch框架实现

def hvp_reverse_over_forward_torch(f, params, v):
    def jvp_fun(params):
        return torch.func.jvp(f, (params,), (v,))[1]
    return torch.func.grad(jvp_fun)(params)

技术优势

  • 仅需单次反向传播过程
  • 内存使用效率达到最优水平
  • 计算复杂度与前向-反向模式持平

策略三:反向-反向双重模式 (Reverse-over-Reverse)

核心理念:先计算梯度与向量的内积以得到标量,然后对该标量执行梯度求解。

执行序列

  1. 标量化处理:scalar = ∇f(θ) · v
  2. 二次梯度计算:∇²f(θ)v = ∇θ(scalar)

JAX框架实现

def hvp_reverse_over_reverse_jax(f, params, v):
    return jax.grad(lambda params: jnp.vdot(jax.grad(f)(params), v))(params)

PyTorch框架实现

def hvp_reverse_over_reverse_torch(f, params, v):
    grad_fun = torch.func.grad(f)
    return torch.func.grad(
        lambda params: torch.sum(grad_fun(params) * v)
    )(params)

技术局限

  • 需要执行两次反向传播,计算开销较大
  • 内存占用达到最高水平
  • 计算图复杂度显著增加

基于Pytroch的实现

接下来,我们将探讨这些方法在实际深度学习模型中的应用实现。以下展示了一套完整的基准测试框架:

模型架构

import torch
from transformers import ViTForImageClassification, BertForSequenceClassification

# 模型配置字典
TORCH_MODELS = {
    'vit_torch': {
        'module': ViTForImageClassification, 
        'model': "google/vit-base-patch16-224",
        'framework': "torch", 
        'num_classes': 1000
    },
    'bert_torch': {
        'module': BertForSequenceClassification, 
        'model': "bert-base-uncased",
        'framework': "torch", 
        'num_classes': 2
    }
}

def loss_fn(params, model, batch):
    """综合损失函数定义"""
    if 'images' in batch.keys():
        # 视觉任务处理分支
        logits = torch.func.functional_call(
            model, params, (batch['images'],)
        ).logits
    else:
        # 自然语言处理分支
        logits = torch.func.functional_call(
            model, params, batch['input_ids'],
            kwargs={k: v for k, v in batch.items() if k != "input_ids"}
        ).logits
    
    # 交叉熵损失与权重衰减的结合
    loss = torch.nn.functional.cross_entropy(logits, batch['labels'])
    weight_decay = 0.0001
    weight_l2 = sum(p.norm()**2 for p in params.values() if p.ndim > 1)
    return loss + weight_decay * 0.5 * weight_l2

三种方法的pytorch实现

def get_hvp_forward_over_reverse(model, batch):
    """前向-反向模式HVP计算器构建"""
    def f(params):
        return loss_fn(params, model, batch)
    
    grad_fun = torch.func.grad(f)
    
    def hvp_fun(params, v):
        return torch.func.jvp(grad_fun, (params,), (v,))[1]
    
    return hvp_fun

def get_hvp_reverse_over_forward(model, batch):
    """反向-前向模式HVP计算器构建"""
    def f(params):
        return loss_fn(params, model, batch)
    
    def jvp_fun(params):
        return torch.func.jvp(f, (params,), (v,))[1]
    
    return torch.func.grad(jvp_fun)

def get_hvp_reverse_over_reverse(model, batch):
    """反向-反向模式HVP计算器构建"""
    def f(params):
        return loss_fn(params, model, batch)
    
    grad_fun = torch.func.grad(f)
    
    def hvp_fun(params, v):
        return torch.func.grad(
            lambda p: sum(
                torch.dot(a.ravel(), b.ravel())
                for a, b in zip(grad_fun(p).values(), v.values())
            ),
            argnums=0
        )(params)
    
    return hvp_fun

实验结果

通过大规模实验验证,我们获得了以下重要发现:

计算时间复杂度分析

计算方法时间复杂度相对基准梯度的倍率
基准梯度计算2×前向计算时间1.0×
前向-反向策略4×前向计算时间2.0×
反向-前向策略4×前向计算时间2.0×
反向-反向策略4×前向计算时间2.0×

关键:HVP计算成本仅为梯度计算的2-3倍,远远低于O(d²)复杂度的完整海塞矩阵构造。

内存使用模式对比

在禁用JIT编译优化的条件下:

  • 反向-反向策略:内存消耗最高(需要维护双重计算图)
  • 前向-反向策略:内存消耗适中
  • 反向-前向策略:内存消耗最低

值得注意的是,启用JIT编译后,三种策略的内存差异显著缩小。