十倍模型计算时间仅增20%:OpenAI开源梯度替换插件

332 阅读8分钟
原文链接: mp.weixin.qq.com

选自GitHub

机器之心编译

参与: 蒋思源、李泽南

训练一个非常深度的神经网络需要大量内存。通过由 OpenAI 研究员 Tim Salimans 和 Yaroslav Bulatov 联合开发的工具包,你可以权衡计算力和内存的使用,从而使你的模型更合理地占用内存。对于前馈模型,我们能够借助该工具把大 10 多倍的模型放在我们的 GPU 上,而计算时间只增加 20%。

项目链接:https://github.com/openai/gradient-checkpointing

通过梯度检查节约内存

深度神经网络训练的内存密集部分是通过反向传播计算损失的梯度。通过查看由你的模型定义的计算图,并在反向传播中重计算这些结点,有可能在减少内存成本的同时计算对应结点的梯度。当训练的深度前馈神经网络包含 n 个层时,你可以这种方式把内存消耗降至 O(sqrt(n)),这需要执行一个额外的前馈传递作为代价(可参见 Training Deep Nets with Sublinear Memory Cost, by Chen et al. (2016))。通过使用 TensorFlow graph editor 自动重写反向传递的计算图,该库提供了 TensorFlow 的一个功能实现。

使用一般 tf.gradient 函数和我们的内存优化的梯度实现训练一个大批量的 ResNet 模型时占用的内存比。

工作原理

对一个简单的 n 层前馈神经网络,获取梯度的计算图如下所示:

神经网络的层级激活值对应于 f 标记的结点,且在正向传播过程中,所有这些结点需要按顺序计算。损失函数对激活值和这些层级参数的梯度使用 b 结点标记,且在反向传播过程中,所有这些结点需要按逆序计算。计算 f 结点的激活值是进一步计算 b 结点梯度的前提要求,因此 f 结点在前向传播后会保留在内存中。只有当反向传播执行地足够远以令计算对应的梯度不再需要使用后面层级的激活值或 f 的子结点时(如下图所示),这些激活值才能从内存中清除。这意味着简单的反向传播要求内存与神经网络的层级数成线性增长关系。下面我们展示了这些结点的计算顺序,紫色的结点表示在给定的时间内需要储存在内存中。

图 1:原版的反向传播

如上所述,简单的反向传播已经是计算最优的了,因为每个结点只需要计算一次。然而,如果我们愿意重新计算结点,那么我们可以节省大量的内存。当我们需要结点的激活值时,我们可以简单地重计算前向传播的结点激活值。我们可以按顺序执行计算,直到计算出需要使用激活值进行反向传播的结点。

图 2:占用内存少的反向传播

使用这一策略,需要令计算梯度的内存在神经网络层的数量 n 上是稳定的,且 n 在内存方面是最优的。但是要注意,结点的计算数量现在扩展了 n^2,相比于之前的 n。n 个结点中的每一个被再计算 n 次。因此计算图变得很慢以计算深度网络,使得这一方法不适用于深度学习。

为了在内存与计算之间取得平衡,我们需要一个策略允许结点被再计算,但是不太经常。这里我们使用的策略是把神经网络激活的一个子集标记为一个结点。

我们选择的检查点结点

这些检查点结点在前向传播后保留在内存中,而其余结点最多只会重新计算一次。在重新计算后,非检查点结点将保留在内存中,直到不再需要它们来执行反向传播。对于简单的前馈神经网络,所有神经元的激活结点都是由正向传播定义的连接点或图的分离点。这意味着我们在反向传播过程中只需要重计算 b 结点和最后检查点之间的结点,当反向传播达到了我们保存的检查点结点,那么所有从该结点开始重计算的结点在内存中都能够移除。计算和内存使用的顺序如下所示:

图 3:Checkpointed backprop

对于例子中的简单前馈网络,最好的选择是将每 qrt(n)-th 个结点作为 checkpoint。这样,checkpoint 结点的数量和 checkpoint 之间的结点数目都是 sqrt(n)的倍数,这意味着所需的内存现在也与我们网络中层数的平方根成比例。由于每个结点最多只能重算一次,因此该策略所需的额外算力相当于整个网络的单次正向传递。

OpenAI 的工具包实现了 checkpointed backprop,如图 3 所示。这是通过标准反向传播(图 1 所示)和 TensorFlow 图编辑器的自动重写实现的。对于包含关结点的图(单结点图分隔符),我们选择自动选择 checkpoints 的策略,使用 sqrt(n),提供 sqrt(n) 给前馈网络。对于只包含多结点分割的一般计算图,我们的 checkpointed backprop 实现仍然有效,但目前仍需使用者手动选择 checkpoint。

更多的计算图、内存用量和梯度计算策略说明可以在这篇文章中找到:https://medium.com/@yaroslavvb/fitting-larger-networks-into-memory-583e3c758ff9。

设置需求

  1. pip install tf-nightly-gpu

  2. pip install toposort networkx pytest

在运行测试的时候,保证能建立 CUDA Profiling Tool Interface(CUPTI),例如,通过运行 export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/extras/CUPTI/lib64"。

使用

本项目提供了一个 TensorFlow 中 tf.gradients 的插入式替换。载入此函数需要:

  1. from memory_saving_gradients import gradients

随后使用 gradients 函数,就像你正常使用 tf.gradients 来计算梯度损失参数一样。(这里假设你明确地调用 tf.gradients,而不是将其隐藏在 tf.train.Optimizer 中。)

除了 tf.gradients 的常规参数以外,OpenAI 的 gradients 函数还有一个额外的参数 checkpoints。Checkpoints 参数告诉 gradients 函数计算图中的哪个结点在前向传播中需要检查。检查点之间的结点会在反向传播时计算。你可以为 checkpoint 提供一个张量列表,gradients(ys,xs,checkpoints=[tensor1,tensor2]),或使用以下关键词:

  • ‘collection(默认)’:这个 checkpoint 的所有张量返回 tf.get_collection('checkpoints')。你随后需要确认自己在定义自己的模型时是使用 tf.add_to_collection('checkpoints', tensor) 来加入张量的。

  • ‘memory’:它使用启发式机制来自动选择 checkpoint 的结点,从而达到我们需要的内存用量 O(sqrt(n))。启发式方法是通过自动识别图中的「关结点」来实现的,即移除时将计算图分成两个断开的张量,然后对这些张量进行检查点确定,找到一个合适的数量。这种方式目前在很多模型上运行良好(但不是所有)。

  • ‘speed’:这个选项试图通过检查所有操作的输出来最大化运行速度,这通常非常耗费算力,特别是在卷积和矩阵乘法上。

覆盖 TF.GRADIENTS

直接使用 gradients 新函数的另一个方法是直接覆盖 Python 上注册的 tf.gradients 函数名。就像这样:

  1. import tensorflow as tf

  2. import memory_saving_gradients

  3. # monkey patch tf.gradients to point to our custom version, with automatic checkpoint selection

  4. def gradients_memory(ys, xs, grad_ys=None, **kwargs):

  5.  return memory_saving_gradients.gradients(ys, xs, grad_ys, checkpoints='memory', **kwargs)

  6. tf.__dict__["gradients"] = gradients_memory

这样,所有 tf.gradients 的调用就会使用节约内存的版本作为代替了。

测试

在 GitHub 资源的测试文件夹中包含用于测试代码准确性,并分析各类模型内存使用情况的脚本。修改代码后,你可以从该文件夹运行./run_all_tests.sh 来进行测试。

下图展示了在 CIFAR10 上运行不同层数 ResNet 的内存用量和时间,Batch-size 为 1280,GPU 为 GeForce GTX 1080:

限制

目前提供的代码在运行模型之前全部使用 Python 进行图操作,这会导致大型图处理速度缓慢。当前用于自动选择 checkpoint 的算法是纯启发式的,预计在已有测试之外的一些模型上可能会失败。在这种情况下,我们应该使用手动选择 checkpoint 的方式。

参考内容

  • Academic papers describing checkpointed backpropagation: Training Deep Nets with Sublinear Memory Cost, by Chen et al. (2016) (https://arxiv.org/pdf/1604.06174.pdf), Memory-Efficient Backpropagation Through Time, by Gruslys et al. (2016) (https://arxiv.org/abs/1606.03401v1)

  • Explanation of using graph_editor to implement checkpointing on TensorFlow graphs: https://github.com/tensorflow/tensorflow/issues/4359#issuecomment-269241038, https://github.com/yaroslavvb/stuff/blob/master/simple_rewiring.ipynb

  • Experiment code/details: https://medium.com/@yaroslavvb/testing-memory-saving-on-v100-8aa716bbdf00

  • TensorFlow memory tracking package: https://github.com/yaroslavvb/chain_constant_memory/blob/master/mem_util_test.py

  • Implementation of "memory-poor" backprop strategy in TensorFlow for a simple feed-forward net: https://github.com/yaroslavvb/chain_constant_memory/

本文为机器之心编译,转载请联系本公众号获得授权

✄------------------------------------------------

加入机器之心(全职记者/实习生):hr@jiqizhixin.com

投稿或寻求报道:content@jiqizhixin.com

广告&商务合作:bd@jiqizhixin.com