本文已参与「新人创作礼」活动,一起开启掘金创作之路。
tensorflow图结点叫OP(operator)。OP是C++写的可以由使用者任意扩展的。扩展OP分两步,1是OP的声明,也就OP注册,使用REGISTER_OP来完成。2是OP的实现,叫op_kernel。KERNEL也需要注册,叫REGISTER_KERNEL_BUILDER。OP在实现时需要继承OpKernel类。
构图时只需要OP声明即可。运行时才需要查找并实例化Kernel。一个OP在不同的设备上可以有不同的实现。下面的例子是官网最简单的ZeroOut OP声明和Kernel的实现。实际上,声明和实现完全可以独立在不同的文件。OP注册在tensorflow之op_wyg_031113的博客-CSDN博客中进行了详细的分析。本文则着重分析Kernel
Kernel是真正实现计算功能的。
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
//OP的声明
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
//OP实现
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output_flat = output_tensor->flat<int32>();
// Set all but the first element of the output tensor to 0.
const int N = input.size();
for (int i = 1; i < N; i++) {
output_flat(i) = 0;
}
// Preserve the first input value if possible.
if (N > 0) output_flat(0) = input(0);
}
};
//注册KERNEL
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
Kernel接口
tensorflow/core/framework/op_kernel.h
同步计算 Compute方法
- kernel计算可以是同步也可以是异步。Compute必须是线程安全。大多数是同步。
- 同步 kernel 绝不能用锁,条件变量等阻塞当前线程,试图在其他kernel里解锁。有
- 因为executor可能只有固定数量的线程,都阻塞就会死锁
- 如果真想加锁,如RecvOp, DequeueOp,必须继承OpKernel的子类AsyncOpKernel。
- 大多数情况下,AsyncOpKerenl应当使用cancellation机制:context->cancellation_manager()
- op的输入输出都要通过参数OpKernelContext context来获得。返回状态也通过ctx->SetStatus()
- 同步计算中,context可以保证函数返回前直存在。
构造与析构
class OpKernel {
public:
//kernel不会在调度器中初始化,所以可以在子类中实现重逻辑
explicit OpKernel(OpKernelConstruction* context);
//允许延时OP. executor会使用OpKernelContext::inc_num_deferred_ops_function()` and
// `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
OpKernel(OpKernelConstruction* context, bool is_deferred);
//能请允许子类自定义NodeDef
OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
bool is_deferred);
virtual ~OpKernel();
//核心计算函数,子类重写它来实现自己的功能
virtual void Compute(OpKernelContext* context) = 0;
virtual AsyncOpKernel* AsAsync() { return nullptr; }
virtual bool IsExpensive() { return expensive_; }
virtual const Tensor* const_tensor() const { return nullptr; }
// Accessors. 能返回结点定义,结点名字,
const NodeDef& def() const { return props_->node_def; }
const std::string& name() const { return props_->node_def.name(); }
};
异步计算:AsyncOpKernel
异步也就是computeAsync要立即返回。当然tensorflow会一直保持context存在,直到done被调用。一但done被调用,不应当再使用context。否则会core。
class AsyncOpKernel : public OpKernel {
public:
using OpKernel::OpKernel; // Lift OpKernel constructors.
//异步计算完成后要调用此回调函数通知调度器。
//只能调用一次,一旦调用,context, 和this都可能已经销毁了
typedef std::function<void()> DoneCallback;
//异步计算就重写此接口
virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
AsyncOpKernel* AsAsync() override { return this; }
void Compute(OpKernelContext* context) override;
};
Kernel构造时的OpKernelConstruction
传入了
- 设备:device
- 分配器Allocator
- 资源管理器:ResourceMgr
- Node
- Env
- FunctionLib
//创建OP时由tensorflow框架创建此类,并传入给OP的构建函数。
class OpKernelConstruction {
public:
//环境,访问操作系统如文件系统,线程创建要使用此env.
Env* env() const { return device_->env(); }
void SetStatus(const Status& status);
const Status& status() const { return *status_; }
//属性时在构图时确定了的,所以在没有运行图时就能获取值。
template <class T>
Status GetAttr(StringPiece attr_name, T* value) const;
const DeviceType& device_type() const { return device_type_; }
FunctionLibraryRuntime* function_library() const { return flib_; }
// Shared resources accessible to this kernel.
ResourceMgr* resource_manager() const { return resource_mgr_; }
// The GraphDef version whose behavior we should follow.
int graph_def_version() const { return graph_def_version_; }
//获取设备
DeviceBase* device() const { return device_; }
};
OP输入输出参数帮助类
有的输入是个List,用一个名字,代表了同类型的多个输入。 可以认为是Tensor tensors[N].输出也有这种情况。
- OpInputList
- OpMutableInputList
- OpOutputList
Compute的参数OpKernelContext
这个类十分巨大,内容丰富。这个Context提供了Op Compute时所需要的一切。从逻辑上讲,可分为以下几类
输入输出参数获取
Input, Output. 至于Attr,是在构图时获得,OpKernelConstruction里就能获取
输出还涉及到Tensor内存分配
执行环境
env, device, resource_mgr, node, graph, session, step_id, function_library, allocator, session
class OpKernelContext {
public:
//基本信息
const SessionMetadata* session_metadata = nullptr;
TensorStore* tensor_store = nullptr;
explicit OpKernelContext(Params* params);
OpKernelContext(Params* params, int num_outputs);
~OpKernelContext();
Env* env() const { return params_->device->env(); }
int64_t step_id() const { return params_->step_id; }
int64_t start_time_usecs() const { return params_->start_time_usecs; }
// 操作op的输入,可以按id,或者名字,只读取或者读写Input
const Tensor& input(int index) const;
Status input(StringPiece name, const Tensor** tensor);
Status input_list(StringPiece name, OpInputList* list);
Status input_ref_mutex(StringPiece name, mutex** out_mutex);
Tensor mutable_input(int index, bool lock_held);
Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
Status mutable_input_list(StringPiece name, OpMutableInputList* list);
void replace_ref_input(int index, const Tensor& tensor, bool lock_held);
Status replace_ref_input(StringPiece name, const Tensor& tensor,
bool lock_held);
void delete_ref_input(int input_index, bool lock_held);
bool has_input(int index) const;
// 操作op的输输出,可以按id,或者名字,只读取或者读写output. 同时可以给output分配内存
Status output_list(StringPiece name, OpOutputList* list);
Status allocate_output(int index, const TensorShape& shape,
Tensor** tensor) TF_MUST_USE_RESULT;
Status allocate_output(StringPiece name, const TensorShape& shape,
Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
AllocatorAttributes attr) TF_MUST_USE_RESULT;
Status allocate_output(StringPiece name, const TensorShape& shape,
Tensor** tensor,
AllocatorAttributes attr) TF_MUST_USE_RESULT;
//分配一个临时tensor变量
Status allocate_temp(DataType type, const TensorShape& shape,
Tensor* out_temp, AllocatorAttributes allocator_attr,
const AllocationAttributes& allocation_attr);
Status allocate_temp(DataType type, const TensorShape& shape,
Tensor* out_temp, AllocatorAttributes allocator_attr) {
return allocate_temp(type, shape, out_temp, allocator_attr,
AllocationAttributes());
}
Status allocate_temp(DataType type, const TensorShape& shape,
Tensor* out_temp) {
return allocate_temp(type, shape, out_temp, AllocatorAttributes());
}
};
kernel实例化
运行时调用如下方法创建Kernel
std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
DeviceBase* device,
Allocator* allocator,
const NodeDef& node_def,
int graph_def_version, Status* status);
std::unique_ptr<OpKernel> CreateOpKernel(
DeviceType device_type, DeviceBase* device, Allocator* allocator,
const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
Status* status);
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
Allocator* allocator, FunctionLibraryRuntime* flib,
const std::shared_ptr<const NodeProperties>& props,
int graph_def_version, OpKernel** kernel);
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
Allocator* allocator, FunctionLibraryRuntime* flib,
ResourceMgr* resource_mgr,
const std::shared_ptr<const NodeProperties>& props,
int graph_def_version, OpKernel** kernel);
Kernel注册同样使用了宏,工厂等
REGISTER_KERNEL_BUILDER流程分析
REGISTER_KERNEL_BUILDER 调用了REGISTER_KERNEL_BUILDER_IMPL 调用了TF_EXTRACT_KERNEL_NAME 调用了TF_EXTRACT_KERNEL_NAME_IMPL 调用了REGISTER_KERNEL_BUILDER_IMPL_2 调用了TF_NEW_ID_FOR_INIT
调用了REGISTER_KERNEL_BUILDER_IMPL_3
// REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument.
// TODO(dodgen): There are some uses of this macro inside functions, where
// kernel_builder refers to (non-const) locals (they should be fixed). To
// accommodate those, kernel_builder.Build() appears as an argument to an
// immediately-called lambda (not in the lambda itself).
#define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr, \
is_system_kernel, ...) \
static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr \
TF_ATTRIBUTE_UNUSED = \
TF_INIT_ON_STARTUP_IF(is_system_kernel || \
(SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \
SHOULD_REGISTER_OP(op_name))) \
<< ([](::tensorflow::KernelDef const* kernel_def) { \
也就是到这里了,使用kernel_factory来注册了一个lambda函数 \
::tensorflow::kernel_factory::OpKernelRegistrar registrar( \
kernel_def, #__VA_ARGS__, \
[](::tensorflow::OpKernelConstruction* context) \
-> ::tensorflow::OpKernel* { \
return new __VA_ARGS__(context); 这里就是在new ZeroOut \
}); \
(void)registrar; \
return ::tensorflow::InitOnStartupMarker{}; \
})(kernel_builder_expr.Build()); //这里的kernel_builder_expr就是KernelDefBuilder,其实就是Name("ZeroOut").Device(DEVICE_CPU).Build(); 而且这里是对lambda函数的调用,所以会立即进入函数内
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);这个定义中,Name实际上是KernelDefBuilder. Device就是KernelDefBuilder::Device.
REGISTER_KERNEL_BUILDER( KernelDefBuilder对象, ZeroOut这个类)。
OpkernelRegistrar的构建函数里最终调用到这个GlobalKernelRegistry的Reigster
void* GlobalKernelRegistry() {
static KernelRegistry* global_kernel_registry = []() {
KernelRegistry* registry = new KernelRegistry;
OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
return registry;
}();
return global_kernel_registry;
}
struct KernelRegistry {
mutex mu;
std::unordered_multimap<string, KernelRegistration> registry //就是放在这个map里了
TF_GUARDED_BY(mu);
};
从动态库加载kernel
tensorflow/core/framework/op_kernel.cc
加载目录:tensorflow/core/kernels目录中的所有so。实际上使用了Env->LoadDynamicLibrary 这种方式是我们扩展tensorflow kernel的方式。直接自己打包成独立的动态库,由tf加载即可。无须与tf源码编译到一起。
void LoadDynamicKernelsInternal() {
Env* env = Env::Default();
env->LoadDynamicLibrary(fullpath.c_str(), &unused_filehandle)); //加载动态库。不同环境不现。比较linux上是加载so文件。windows是加载dll文件。
}
void LoadDynamicKernels() {
//只调用一次
static absl::once_flag dll_loader_flag;
absl::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
}
Kernel从context中获取输入,分配输出时返回错误
tensorflow/core/framework/op_requires.h中定义了大量的宏,帮助我们实现这些功能。这些宏能根据需要返回错误。这宏非常实用,避免我们写大量的if判断,return返回之类的代码
#define OP_REQUIRES_OK(CTX, ...) \
do { \
::tensorflow::Status _s(__VA_ARGS__); \
if (!TF_PREDICT_TRUE(_s.ok())) { \
CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
return; \
} \
} while (0)