【深度学习编译】算子编译 IR 转换

1,230 阅读9分钟

文 @ 小 P 爱发呆

0 前言

今天想跟大家讨论的是深度学习领域编译技术中的关键过程之一:Intermediate Representation (IR) lowering

当下 AI 已然是大家普遍关注的话题之一,它既是我们茶余饭后的畅谈话题,也是顶级会议上资深学者们激情探讨的研究热点。

然而,AI 的发展除了依赖深度学习神经网络的不断优化、计算算力的提升,同时也依赖于深度学习网络大规模并行机制的高效性,以及深度学习网络中热点算子对高性能处理器的利用效率。业界也在不同的软硬件层次上提供了多样化的解决方案和支撑方式。其中,Deep Learning Framework + Vendor Highly Optimized Operator Libraries + Domain Specific Compiler (DSC) 的软件模式广泛为大家所接受。

Deep learing Framework 的 Eager mode 执行方式保证了用户可以在应用层使用基于高级编程语言丰富的表达能力。然而,Eager mode 虽然提供了易用性以及可表达性,性能却是短板。Vendor Optimized Operator Libraries 针对常见的性能瓶颈算子, 提供了基于硬件厂商处理器软硬件特性的高效实现,但是算子库往往依赖于领域专家的深度手工优化,需要耗费较多的时间和人力。

DSC 针对无法提前预知的 Tensor Computations,提供了一种自动生成可运行在加速器如 GPU 等多种后端高效代码的有效方式,这为多后端快速部署提供了可能性。不过,DSC 的 IR 往往只提供了上层编程语言的一个子集。

本文主要讨论在算子层,不同张量运算到 DSC 中间表示的 lowering 过程

说到深度学习算子优化,大家经常会想到两个典型算子,即卷积运算和矩阵乘法。

现在的常用计算库都提供了高性能卷积运算以及矩阵乘法运算,甚至囊括了对不同形状张量操作的优化。而本次讨论主要聚焦在长尾算子,即算法研究员在设计算法时,基于 Python 构建的由一系列 element-wise 操作、indexing 操作,以及 reduce 操作构成的函数。

通过深度学习编译器为长尾算子生成可运行在如 GPU 等高性能处理器上的优化代码,具有重大意义,原因在于:

  1. 长尾算子源源不断产生,算法研究员为了设计新的算法,会不断构建新的长尾算子;
  2. 长尾算子常常依赖于深度学习编译器框架的 Python API 构建,主要的操作为 element-wise、reduce 以及 indexing,但是函数可以实现的功能是多种多样,不可预测的。

不同的张量运算,其在 DSC 中的 IR 也往往有着不同的挑战。本文主要对长尾算子编译支持中的核心 IR 转换进行讨论,主要包括以下几个方面:

  • 主流 IR stack 探讨
  • 计算型算子 IR 转换
  • 访存型算子 IR 转换
  • 小结

1 主流 IR stack 探讨

深度学习编译器的主要流程如下图所示,这里主要以 Python 构建的长尾算子作为输入,以基于 schedule primitive 的 TVM IR 层为例来讲述。

表征的主要流程为:

  • 在 graph-level,从基于高级语言 Python 编写的源代码到 Expression-based Relay IR,主要会做一些图优化;
  • 在 operator-level,从 Relay IR 到 Tensor IR,主要为了对算子进行深度优化;
  • 将 Tensor IR 嵌入 Optimization info;
  • 将 Tensor IR 转换成接近后端编程语言的 statement IR;
  • 将 statement IR 转换成后端代码。

pipeline

值得注意的是,Python 语义允许张量 inplace 更新,存在非 SSA 的语义,而当前很多编译器都是基于 SSA IR,因此,在 Python 翻译时,需要非 SSA 到 SSA 的等价转换过程。

例如,Relay IR 便属于 SSA IR,再如,PyTorch 内嵌编译器中使用的中间表示 torchscript,定义了一套基于 SSA 的静态类型 IR,其可以等价表达 Python 语法特征的一个子集。即在翻译基于 Python 的算子时,会确定数据类型,并等价转换为 SSA IR。

DSC 的第一层中间表示往往记录了从 Python 源码翻译而来接近算子代码结构的详细信息,在算子层,基于计算序列解析将其转换成为代码生成后端可以识别的 Expression-based tensor IR,进而嵌入优化信息,生成后端代码。该 IR 将所有的计算序列都解析成表达式的形式,而基于 Python 编写的长尾算子包含的操作序列是多种多样的,IR lowering 则需要分析不同的计算访存特征进行 IR 转换。

我们根据操作类型是否包含算术或者逻辑运算,将操作序列解析分为计算型算子和访存型算子的解析。这方面的内容,将在接下来的小节中详细讨论。

2 计算型算子 IR 转换

计算型算子从 primitive IR 到 Expression-based tensor IR 的转换过程,即为将计算pattern 记录在算子计算表达式中的过程。

我们将常见的计算型算子分为单目运算、双目运算、条件表达式,以及设备端 API。相应地,可设计四类表达式,以对应记录的计算 pattern,如下代码段所示,其中单个小写字母表示某个特定张量。

  • 单目运算 需要传入一个输入张量,以及单目运算符类型;
  • 双目运算符 需要传入两个张量,以及双目运算符类型;
  • 条件表达式 需要传入条件表达式,以及两个输入表达式;
  • 设备端 API 需要传入参数列表,以及函数名。

这里,我们将运算符按类别分别定义在四个枚举类中,便于管理和实现。值得注意的是,在解析时,算子需要满足 Python 语义中的 broadcast 约束

    %u : UnaryType::Abs(%v)                  -> u = abs(v)          
    %c : BinaryType::Add(%a, %b)             -> c = a + b
    %t : Select(%x, %y, %z)                  -> t = x ? y : z
    %d : SpecFunc::Relu(%e)                  -> d = relu(e)

3 访存型算子 IR 转换

访存型算子 IR 转换的关键在于访存 pattern 的记录。

计算型算子的 IR 转换中,我们用到了常用的分类方法,并通过定义枚举类来统一管理不同的计算模式。而访存型算子很难通过这样的分类来统一管理,往往需要根据算子本身的访存 pattern 编写专门的复杂访存表达式,特别地,某些访存型算子源于 Python 复杂的语法特征,除了需要编写访存 pattern,还需要特定的 IR 设计和实现来支持。

举例来说:

  • cast: 类型转换操作,属于简单的访存算子,仅需记录转换类型以及输入张量。
  • stack: 将不确定数目的输入张量,按照给定的维度进行张量拼接,则需要定义访存表达式,将输入张量按次沿拼接维度拼接起来;
  • Numpy basic slicing and indexing 是一种重要且复杂的语法特征,其基于 view 语义,即会返回一个张量,但是该张量只是源张量的一个指向,并且记录了区别于源张量的访存方式。

它的表现形式主要在下标运算符中定义对维度的访问方式。如下代码段为从实际网络中截取的一段 Numpy-style indexing 综合实例热点代码,囊括了 3 种 basic indexing 的情况:

  • 下标运算符中 start:stop:step 形式引发的 slice 操作,如 offset[:,1::4]
  • None 关键字引发的添加大小为 1 的维度的操作,如 ctr_x[:, None]
  • basic indexing 发生在赋值表达式等号左侧而引发 inplace update 的操作,如最后一行赋值表达式, pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w

Numpy-style Indexing 综合实例:

    def offset2bbox(boxes, offset, weight=(1.0, 1.0, 1.0, 1.0)):
        ctr_x, ctr_y, widths, heights = xyxy2xywh(boxes)
        
        wx, wy, ww, wh = weight
        dx = offset[:, 0::4] / wx
        dy = offset[:, 1::4] / wy
        dw = offset[:, 2::4] / ww
        dh = offset[:, 3::4] / wh
        
        ...
        
        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]
        
        ...
        
        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
        
        ...
        
        return pred_boxes

上述举例的三种访存型算子,其中 cast 和 stack 需编写专门的访存表达式,而 Numpy Basic Indexing,如 Numpy-stype indexing 综合实例中所展示的一系列操作在代码生成流程中便引发了三个亟待解决的问题:

  • 基于 Python 的 view 语义在如 TVM 等生成框架中的表示问题;
  • 索引计算的问题;
  • inplace unpdate 的问题。

关于 indexing 引发的 view 表示问题以及 inplace update 问题,可以在 tensor IR 转换过程中设计 SSA IR 转换模块,以简化非 SSA 表示带来的复杂性,或者也可以在 tensor IR 内部新增 IR 设计,表示 view 语义特征以及 inplace udpate 更新方式。

索引计算问题则需要将 indexing 规则基于 expression-based tensor IR 更新在 tensor 的索引规则中,主要体现在张量的下标访问方式上。

这里简单例举了三种访存型算子,而实际上的访存型算子是复杂多变的,如 permute、reshape 等操作,均需要特殊支持。

4 小结

这篇小文给出了对于深度学习编译器在一种长尾算子代码生成场景中 IR stack 的讨论。虽然单个长尾算子在整个神经网络中可能耗时比例较小,但是神经网络中可能包含数十个甚至数百个长尾算子,并且在未来也会不断地创造这种需求。因此,基于编译手段,不断增强支持不同类型的算子翻译,具有重要意义。

当然,计算型算子和访存型算子有着不同的特征——计算型算子需要注意计算 pattern 的翻译,并需要满足形状信息约束;而访存型算子则需要注意访存 pattern 的翻译,有些操作潜在囊括于 Python 复杂的语法定义中,这时便需要在深度学习编译器中对应设计 IR 以支持 IR lowering 以及代码生成流程。

总而言之,通过上述讨论可以发现,对于访存型算子的翻译和支持往往是长尾算子高性能代码生成的痛点和难点。DSC 往往对张量运算有很好的表示和支持方式,本文的讨论也集中在张量运算上,但是,当我们面对越来越复杂的网络,DSC 对于非张量运算支持的需求是否也在不断增加呢,例如控制流?我们在哪个层面对其进行表示和优化会是更好的解决方案呢?笔者希望能够和大家有更多的探讨。


感谢阅读,欢迎在评论区留言讨论哦~

P.S. 如果喜欢本篇文章,请多多 点赞,让更多的人看见我们 :D

关注 公众号「SenseParrots」,获取人工智能框架前沿业界动态与技术思考。