一、前言
torch.compile()是 PyTorch 团队的一项重大创新。我们可以将它看作是模型的编译器,它与模型的关系类似于 GCC 和 C、C++的关系,只不过它专为优化深度学习模型设计。这项突破让 PyTorch 模型比以往更快且更高效。
而最好的地方在于,使用torch.compile()非常简单。你只需要在 PyTorch 或 Huggingface 模型上加一行代码:
import torch
model = torch.compile(model)
虽然这行代码会带来额外的开销,但它的好处远高于其带来的开销。编译后的模型模型更快,让模型在推理和训练时有更好的表现。
二、torch.compile()如何提速模型?
现在我们知道用 torch.compile()可以提速模型。这种提升主要来自两个因素:
- 减少 Python 开销
- 优化 GPU 读写
下面我们详细看看细节。
2.1 减少Python开销
在不使用 torch.compile()时,PyTorch 在 Eager 模式下运行,此时 Python 解释器逐行执行代码。这个模式下,解释器不知道后面要执行的代码。这种缺乏远见的执行方式效率低下,因为无法全局进行优化。
在使用 torch.compile()后,会一次性分析整个模型。编译器会审查模型的所有操作。这种全局视野允许优化执行过程。 torch.compile()将整个模型当做一个整体,并删除不必要的步骤,从而优化的过程。
2.2 优化GPU读写
为了理解 torch.compile()如何优化 GPU 读写,我们需要先深入理解 GPU 如何处理数据和计算。
2.2.1 GPU如何工作
让我们用一个简单的解释来解释一下:
所有电脑都有 CPU,CPU 通常有几个强大的核心组成,这些核心与内存连接。CPU 和内存通信以执行操作,这个我们已经很熟悉了。
GPU 和 CPU 在架构和内存设计上不同:
- 核心:GPU 有很多简单的核心而非少数 CPU 中强大的核心
- 内存:GPU 用 Hight Bandwith Memory(HBM),比传统的 DRAM 快的多,但是容量有限制
2.2.2 CPU 如何计算
- 所有模型参数、激活值和其他变量存储到 GPU 的 HBM 中
- 执行任意操作,都要求输入数据从 HBM 传输到 GPU 核心
- 在 GPU 内部,数据通过 GPU 芯片上的缓存和寄存器进行计算
- 计算完成后,结果协会 HBM
HBM 和 GPU 核心间持续的数据传输是非常耗时的,而且非常影响表现。尽管 GPU 核心内的操作非常快,但 HBM 和 GPU 核心之间的内存带宽通常会成为瓶颈。
2.2.3 面临挑战:GPU内存带宽
每个 GPU 的带宽都是固定的,它定义了数据在其内存(HBM)和计算核心之间移动的速率。这成了深度学习任务中的关键瓶颈,因为深度学习任务需要频繁访问内存:
- 加载输入数据到 GPU 核心
- 在 GPU 中执行计算
- 将输出写回到 HBM
在 使用torch.compile()时,PyTorch 处理 Eager 模式顺序处理操作,这种来回读写会非常低效,显著提升了完整操作的时间。
2.3 torch.compile()如何解决频繁读写
torch.compile()通过两个关键方法优化 GPU 读写:
- 减少内存传输:编译器预先分析整个计算图,并最大限度地减少 HBM 和 GPU 核心之间不必要的数据传输
- 核融合:多个操作被融合到一个 GPU 核心中,减少核心和 HBM 之间的传输
2.3.1 核融合:关键优化
核融合将多个操作合并到一个核心中。融合核不再需要执行多次来回的内存操作,而是一次性执行所有操作。这降低了开销并提高了执行效率。
2.3.2 为什么重要?
为了更好地理解影响,请考虑 GPU 内存层次结构:
- HBM:速度快但容量有限(约 40-80GB)
- 片上存储器:速度极快但很小(约几十 MB)
这些内存层之间频繁的数据移动会导致效率低下。torch.compile() 通过简化数据传输和减少内存带宽瓶颈来解决此问题。
三、结论
torch.compile()通过一下方式让 PyTorch 模型运行更快:
- 减少 Python 开销:将整个模型编译成优化对象可消除动态解释中的低效率。
- 优化 GPU 读写:像核融合技术极大减少了内存传输并且极大提升了 GPU 表现。
只需一行代码,就可以实现令人印象深刻的加速,并充分发挥硬件在深度学习工作负载方面的潜力。无论是在训练大规模模型还是大规模运行推理,torch.compile() 都是应该添加到深度学习库中的强大工具。