torch_musa中ADVANCED_REGISTER解读

188 阅读2分钟

juejin.cn/post/731414… 所述,pytorch通过TORCH_LIBRARY_IMPL宏进行算子注册,如下面代码所示:

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) { 
    m.impl(<myadd_schema>, &myadd_autograd);
}

而在torch_musa中,则对TORCH_LIBRARY_IMPL进行封装,使用ADVANCED_REGISTER宏进行注册,如下面代码所示:

ADVANCED_REGISTER(aten, PrivateUse1, "max", MaxAll)

使用ADVANCED_REGISTER宏之后,用户可以通过设置环境变量打印所调用的算子信息,如输入输出、算子名字等,这相当于实现了类似于python的装饰器的功能——即给函数添加warpper。

注:ADVANCED_REGISTER宏由torch_musa团队的Yueran Tang所实现。

该宏定义在torch_musa/csrc/utils/register_wrapper.h文件中,如下代码所示:

#define REGISTER_IMPL(lib, key, yaml, func, name)               \
  using namespace at::musa;                                     \
  //为算子定义一个结构体                                          \
  template <class F, F f>                                       \
  struct wrapper_##name;                                        \   
  template <class R, class... Args, R (*f)(Args...)>            \
  //为和算子对应的结构体定义偏特化模板,                           \
  //板参数为函数指针类型和函数指针,                               \
  //下面定义的另一种偏特化模板的存在,                             \
  //传入该模板的函数的范围值类型不能为void                         \
  struct wrapper_##name<R (*)(Args...), f> {                    \   
    //该偏特化模板有一个warp静态函数,输入为函数的所有输入参数       \
    static R wrap(Args... args) {                               \   
      if (!GlobalConfig.IsOpEnabled(yaml, #func)) {             \
        return f(args...);                                      \
      }                                                         \
      GlobalConfig.CreateKernelStream(yaml, #func);             \
      GlobalConfig.enabled_ = false;                            \
      TraversalItems(args...);                                  \
      GlobalConfig.enabled_ = true;                             \
      GlobalConfig.SplitKernelIO();                             \
      //实际调用算子函数的地方                                    \
      R&& result = f(args...);                                  \
      GlobalConfig.enabled_ = false;                            \
      TraversalItems(result);                                   \
      GlobalConfig.enabled_ = true;                             \
      GlobalConfig.CloseKernelStream();                         \
      return std::forward<R>(result);                           \
    }                                                           \
  };                                                            \
  //为和算子对应的结构体定义另一种偏特化模板,                      \
  //其模板参数为函数指针类型(返回值类型为void)和函数指针           \
  template <class... Args, void (*f)(Args...)>                  \    
  struct wrapper_##name<void (*)(Args...), f> {                 \
    //该偏特化模板有一个warp静态函数,输入为函数的所有输入参数       \
    static void wrap(Args... args) {                            \    
      if (!GlobalConfig.IsOpEnabled(yaml, #func)) {             \
        return f(args...);                                      \
      }                                                         \
      GlobalConfig.CreateKernelStream(yaml, #func);             \
      GlobalConfig.enabled_ = false;                            \
      TraversalItems(args...);                                  \
      GlobalConfig.enabled_ = true;                             \
      GlobalConfig.SplitKernelIO();                             \
      //实际上调用函数的地方                                      \
      f(args...);                                               \
      GlobalConfig.CloseKernelStream();                         \
    }                                                           \
  };                                                            \
  TORCH_LIBRARY_IMPL(lib, key, m) {                             \
    //针对原本要注册的算子函数生成warpper结构体,                  \
    //然后将生成的warpper结构体的warp静态函数作为要注册的算子函数。 \
    //这个warpper中调用了原本要注册算子函数,并在前后进行了一些操作  \
    //如同python中的装饰器                                       \
    m.impl(yaml, &wrapper_##name<decltype(&func), func>::wrap); \  
  }

// lib = aten, key = PrivateUse1, yaml = torch op yaml name, func = kernel.
#define ADVANCED_REGISTER(lib, key, yaml, func) \
  REGISTER_IMPL(lib, key, yaml, func, func)