如juejin.cn/post/731414… 所述,pytorch通过TORCH_LIBRARY_IMPL宏进行算子注册,如下面代码所示:
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
m.impl(<myadd_schema>, &myadd_autograd);
}
而在torch_musa中,则对TORCH_LIBRARY_IMPL进行封装,使用ADVANCED_REGISTER宏进行注册,如下面代码所示:
ADVANCED_REGISTER(aten, PrivateUse1, "max", MaxAll)
使用ADVANCED_REGISTER宏之后,用户可以通过设置环境变量打印所调用的算子信息,如输入输出、算子名字等,这相当于实现了类似于python的装饰器的功能——即给函数添加warpper。
注:ADVANCED_REGISTER宏由torch_musa团队的Yueran Tang所实现。
该宏定义在torch_musa/csrc/utils/register_wrapper.h文件中,如下代码所示:
#define REGISTER_IMPL(lib, key, yaml, func, name) \
using namespace at::musa; \
//为算子定义一个结构体 \
template <class F, F f> \
struct wrapper_##name; \
template <class R, class... Args, R (*f)(Args...)> \
//为和算子对应的结构体定义偏特化模板, \
//板参数为函数指针类型和函数指针, \
//下面定义的另一种偏特化模板的存在, \
//传入该模板的函数的范围值类型不能为void \
struct wrapper_##name<R (*)(Args...), f> { \
//该偏特化模板有一个warp静态函数,输入为函数的所有输入参数 \
static R wrap(Args... args) { \
if (!GlobalConfig.IsOpEnabled(yaml, #func)) { \
return f(args...); \
} \
GlobalConfig.CreateKernelStream(yaml, #func); \
GlobalConfig.enabled_ = false; \
TraversalItems(args...); \
GlobalConfig.enabled_ = true; \
GlobalConfig.SplitKernelIO(); \
//实际调用算子函数的地方 \
R&& result = f(args...); \
GlobalConfig.enabled_ = false; \
TraversalItems(result); \
GlobalConfig.enabled_ = true; \
GlobalConfig.CloseKernelStream(); \
return std::forward<R>(result); \
} \
}; \
//为和算子对应的结构体定义另一种偏特化模板, \
//其模板参数为函数指针类型(返回值类型为void)和函数指针 \
template <class... Args, void (*f)(Args...)> \
struct wrapper_##name<void (*)(Args...), f> { \
//该偏特化模板有一个warp静态函数,输入为函数的所有输入参数 \
static void wrap(Args... args) { \
if (!GlobalConfig.IsOpEnabled(yaml, #func)) { \
return f(args...); \
} \
GlobalConfig.CreateKernelStream(yaml, #func); \
GlobalConfig.enabled_ = false; \
TraversalItems(args...); \
GlobalConfig.enabled_ = true; \
GlobalConfig.SplitKernelIO(); \
//实际上调用函数的地方 \
f(args...); \
GlobalConfig.CloseKernelStream(); \
} \
}; \
TORCH_LIBRARY_IMPL(lib, key, m) { \
//针对原本要注册的算子函数生成warpper结构体, \
//然后将生成的warpper结构体的warp静态函数作为要注册的算子函数。 \
//这个warpper中调用了原本要注册算子函数,并在前后进行了一些操作 \
//如同python中的装饰器 \
m.impl(yaml, &wrapper_##name<decltype(&func), func>::wrap); \
}
// lib = aten, key = PrivateUse1, yaml = torch op yaml name, func = kernel.
#define ADVANCED_REGISTER(lib, key, yaml, func) \
REGISTER_IMPL(lib, key, yaml, func, func)