tensorflow源码分析之kernel

1,114 阅读5分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

 tensorflow图结点叫OP(operator)。OP是C++写的可以由使用者任意扩展的。扩展OP分两步,1是OP的声明,也就OP注册,使用REGISTER_OP来完成。2是OP的实现,叫op_kernel。KERNEL也需要注册,叫REGISTER_KERNEL_BUILDER。OP在实现时需要继承OpKernel类。

构图时只需要OP声明即可。运行时才需要查找并实例化Kernel。一个OP在不同的设备上可以有不同的实现。下面的例子是官网最简单的ZeroOut OP声明和Kernel的实现。实际上,声明和实现完全可以独立在不同的文件。OP注册在tensorflow之op_wyg_031113的博客-CSDN博客中进行了详细的分析。本文则着重分析Kernel

Kernel是真正实现计算功能的。

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
//OP的声明
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

//OP实现
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};
//注册KERNEL
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

Kernel接口

tensorflow/core/framework/op_kernel.h

同步计算 Compute方法

  1. kernel计算可以是同步也可以是异步。Compute必须是线程安全。大多数是同步。
  2. 同步 kernel 绝不能用锁,条件变量等阻塞当前线程,试图在其他kernel里解锁。有
  3. 因为executor可能只有固定数量的线程,都阻塞就会死锁
  4. 如果真想加锁,如RecvOp, DequeueOp,必须继承OpKernel的子类AsyncOpKernel。
  5. 大多数情况下,AsyncOpKerenl应当使用cancellation机制:context->cancellation_manager()
  6. op的输入输出都要通过参数OpKernelContext context来获得。返回状态也通过ctx->SetStatus()
  7. 同步计算中,context可以保证函数返回前直存在。
构造与析构
class OpKernel {
 public:
  //kernel不会在调度器中初始化,所以可以在子类中实现重逻辑
  explicit OpKernel(OpKernelConstruction* context);

  //允许延时OP. executor会使用OpKernelContext::inc_num_deferred_ops_function()` and
  // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
  OpKernel(OpKernelConstruction* context, bool is_deferred);
  //能请允许子类自定义NodeDef
  OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
           bool is_deferred);

  virtual ~OpKernel();
     
  //核心计算函数,子类重写它来实现自己的功能
  virtual void Compute(OpKernelContext* context) = 0;
  virtual AsyncOpKernel* AsAsync() { return nullptr; }
  virtual bool IsExpensive() { return expensive_; }
  virtual const Tensor* const_tensor() const { return nullptr; }
  // Accessors. 能返回结点定义,结点名字,
  const NodeDef& def() const { return props_->node_def; }
  const std::string& name() const { return props_->node_def.name(); }
};

异步计算:AsyncOpKernel

异步也就是computeAsync要立即返回。当然tensorflow会一直保持context存在,直到done被调用。一但done被调用,不应当再使用context。否则会core。

class AsyncOpKernel : public OpKernel {
 public:
  using OpKernel::OpKernel;  // Lift OpKernel constructors.

  //异步计算完成后要调用此回调函数通知调度器。
  //只能调用一次,一旦调用,context, 和this都可能已经销毁了
  typedef std::function<void()> DoneCallback;
  //异步计算就重写此接口
  virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;

  AsyncOpKernel* AsAsync() override { return this; }

  void Compute(OpKernelContext* context) override;
};

Kernel构造时的OpKernelConstruction

传入了

  1. 设备:device
  2. 分配器Allocator
  3. 资源管理器:ResourceMgr
  4. Node
  5. Env
  6. FunctionLib
//创建OP时由tensorflow框架创建此类,并传入给OP的构建函数。
class OpKernelConstruction {
 public:

  //环境,访问操作系统如文件系统,线程创建要使用此env.
  Env* env() const { return device_->env(); }
 
  void SetStatus(const Status& status);
  const Status& status() const { return *status_; }
  //属性时在构图时确定了的,所以在没有运行图时就能获取值。
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;
  const DeviceType& device_type() const { return device_type_; }

  FunctionLibraryRuntime* function_library() const { return flib_; }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return resource_mgr_; }

  // The GraphDef version whose behavior we should follow.
  int graph_def_version() const { return graph_def_version_; }

  //获取设备
  DeviceBase* device() const { return device_; }

};

OP输入输出参数帮助类

有的输入是个List,用一个名字,代表了同类型的多个输入。 可以认为是Tensor  tensors[N].输出也有这种情况。

  1. OpInputList
  2. OpMutableInputList
  3. OpOutputList

Compute的参数OpKernelContext

这个类十分巨大,内容丰富。这个Context提供了Op Compute时所需要的一切。从逻辑上讲,可分为以下几类

输入输出参数获取

Input, Output. 至于Attr,是在构图时获得,OpKernelConstruction里就能获取

输出还涉及到Tensor内存分配

执行环境

env, device, resource_mgr, node, graph, session, step_id, function_library, allocator, session


class OpKernelContext {
 public:
  //基本信息
   const SessionMetadata* session_metadata = nullptr;
   TensorStore* tensor_store = nullptr;
  explicit OpKernelContext(Params* params);
  OpKernelContext(Params* params, int num_outputs);
  ~OpKernelContext();
  Env* env() const { return params_->device->env(); }
  int64_t step_id() const { return params_->step_id; }
  int64_t start_time_usecs() const { return params_->start_time_usecs; }



  // 操作op的输入,可以按id,或者名字,只读取或者读写Input

  const Tensor& input(int index) const;
  Status input(StringPiece name, const Tensor** tensor);
  Status input_list(StringPiece name, OpInputList* list);
  Status input_ref_mutex(StringPiece name, mutex** out_mutex);
  Tensor mutable_input(int index, bool lock_held);
  Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
  Status mutable_input_list(StringPiece name, OpMutableInputList* list);
  void replace_ref_input(int index, const Tensor& tensor, bool lock_held);
  Status replace_ref_input(StringPiece name, const Tensor& tensor,
                           bool lock_held);
  void delete_ref_input(int input_index, bool lock_held);
  bool has_input(int index) const;

 // 操作op的输输出,可以按id,或者名字,只读取或者读写output. 同时可以给output分配内存
  Status output_list(StringPiece name, OpOutputList* list);
  Status allocate_output(int index, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
  Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
  //分配一个临时tensor变量
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr,
                       const AllocationAttributes& allocation_attr);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr) {
    return allocate_temp(type, shape, out_temp, allocator_attr,
                         AllocationAttributes());
  }
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp) {
    return allocate_temp(type, shape, out_temp, AllocatorAttributes());
  }

};

kernel实例化

运行时调用如下方法创建Kernel

std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
                                         DeviceBase* device,
                                         Allocator* allocator,
                                         const NodeDef& node_def,
                                         int graph_def_version, Status* status);

std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
    Status* status);

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel);

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      ResourceMgr* resource_mgr,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel);

Kernel注册同样使用了宏,工厂等

REGISTER_KERNEL_BUILDER流程分析

REGISTER_KERNEL_BUILDER 调用了REGISTER_KERNEL_BUILDER_IMPL 调用了TF_EXTRACT_KERNEL_NAME 调用了TF_EXTRACT_KERNEL_NAME_IMPL 调用了REGISTER_KERNEL_BUILDER_IMPL_2 调用了TF_NEW_ID_FOR_INIT
调用了REGISTER_KERNEL_BUILDER_IMPL_3

// REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument.
// TODO(dodgen): There are some uses of this macro inside functions, where
// kernel_builder refers to (non-const) locals (they should be fixed). To
// accommodate those, kernel_builder.Build() appears as an argument to an
// immediately-called lambda (not in the lambda itself).
#define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr,   \
                                       is_system_kernel, ...)               \
  static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr      \
      TF_ATTRIBUTE_UNUSED =                                                 \
          TF_INIT_ON_STARTUP_IF(is_system_kernel ||                         \
                                (SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \
                                 SHOULD_REGISTER_OP(op_name)))              \
          << ([](::tensorflow::KernelDef const* kernel_def) {               \
               也就是到这里了,使用kernel_factory来注册了一个lambda函数      \
               ::tensorflow::kernel_factory::OpKernelRegistrar registrar(   \
                   kernel_def, #__VA_ARGS__,                                \
                   [](::tensorflow::OpKernelConstruction* context)          \
                       -> ::tensorflow::OpKernel* {                         \
                     return new __VA_ARGS__(context); 这里就是在new ZeroOut               \
                   });                                                      \
               (void)registrar;                                             \
               return ::tensorflow::InitOnStartupMarker{};                  \
             })(kernel_builder_expr.Build()); //这里的kernel_builder_expr就是KernelDefBuilder,其实就是Name("ZeroOut").Device(DEVICE_CPU).Build(); 而且这里是对lambda函数的调用,所以会立即进入函数内

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);这个定义中,Name实际上是KernelDefBuilder. Device就是KernelDefBuilder::Device. 

REGISTER_KERNEL_BUILDER( KernelDefBuilder对象, ZeroOut这个类)。

OpkernelRegistrar的构建函数里最终调用到这个GlobalKernelRegistry的Reigster

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = []() {
    KernelRegistry* registry = new KernelRegistry;
    OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
    return registry;
  }();
  return global_kernel_registry;
}

struct KernelRegistry {
  mutex mu;
  std::unordered_multimap<string, KernelRegistration> registry //就是放在这个map里了
      TF_GUARDED_BY(mu);
};

从动态库加载kernel

tensorflow/core/framework/op_kernel.cc

加载目录:tensorflow/core/kernels目录中的所有so。实际上使用了Env->LoadDynamicLibrary 这种方式是我们扩展tensorflow kernel的方式。直接自己打包成独立的动态库,由tf加载即可。无须与tf源码编译到一起。


void LoadDynamicKernelsInternal() {
  Env* env = Env::Default();
  env->LoadDynamicLibrary(fullpath.c_str(), &unused_filehandle)); //加载动态库。不同环境不现。比较linux上是加载so文件。windows是加载dll文件。
}

void LoadDynamicKernels() {
  //只调用一次
  static absl::once_flag dll_loader_flag;
  absl::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
}

 Kernel从context中获取输入,分配输出时返回错误

tensorflow/core/framework/op_requires.h中定义了大量的宏,帮助我们实现这些功能。这些宏能根据需要返回错误。这宏非常实用,避免我们写大量的if判断,return返回之类的代码

#define OP_REQUIRES_OK(CTX, ...)                             \
  do {                                                       \
    ::tensorflow::Status _s(__VA_ARGS__);                    \
    if (!TF_PREDICT_TRUE(_s.ok())) {                         \
      CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s);  \
      return;                                                \
    }                                                        \
  } while (0)