tensorflow源码分析之OP

418 阅读3分钟

 operator(op)是tensorflow扩展功能的的方式。OP分为声明和定义。声明叫op,实现叫kernel.一个声明可以有多个实现。或者说在不同设备上的不同实现。OP需要注册。

时刻注意,OP只是一个声明。如同C++的函数声明。并不涉及这些OP如何实现。比如可以声明一个OP叫Add,其功能是可以做两个数的加法int Add(int a, int b); 而这个声明用一个proto message表示就是message OpDef。而图就是多个OP的输入输出首尾相接组成的有向无环图,这个图实际上表示了函数的调用关系。

OP注册中心接口

只提供了根据名字查找OP的接口。tensorflow/core/framework/op.h


class OpRegistryInterface {
 public:
  virtual ~OpRegistryInterface();
  virtual Status LookUp(const std::string& op_type_name,
                        const OpRegistrationData** op_reg_data) const = 0;

  Status LookUpOpDef(const std::string& op_type_name,
                     const OpDef** op_def) const;
};



//实际的一个实现
class OpRegistry : public OpRegistryInterface {
 public:
  typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;

  OpRegistry();
  ~OpRegistry() override;

  void Register(const OpRegistrationDataFactory& op_data_factory);

  Status LookUp(const std::string& op_type_name,
                const OpRegistrationData** op_reg_data) const override;
  mutable mutex mu_;
  // Functions in deferred_ may only be called with mu_ held.
  mutable std::vector<OpRegistrationDataFactory> deferred_ TF_GUARDED_BY(mu_);
  // Values are owned.
  mutable std::unordered_map<string, const OpRegistrationData*> registry_
      TF_GUARDED_BY(mu_); //op就是注册在这里了
  mutable bool initialized_ TF_GUARDED_BY(mu_);

  // Registry watcher.
  mutable Watcher watcher_ TF_GUARDED_BY(mu_);

  std::function<Status(const OpRegistryInterface&)> op_registry_validator_;
};

看这几个接口很简单,但是其参数OpDef, OpRegistrationData很复杂。

OpDef

一个op有多个输入参数,和多个输入属性,还有多个输出参数,多个控制输出。它们都是Tensor。

输入属性的值在构图时已经确定不变了。而输入参数是执行图时变化数据。

​编辑

class OpDef 是定义在proto中的。tensorflow/core/framework/op_def.proto

这个proto就声明了个OP.实际上就是把输入输出参数,OP名字等等元信息保存下来。


message OpDef {

  string name = 1; // op名字
  message ArgDef { // op输入输出参数
    string name = 1;
    string description = 2;
    DataType type = 3;
    string type_attr = 4;    // if specified, attr must have type "type"
    string number_attr = 5;  // if specified, attr must have type "int"
    string type_list_attr = 6;
    repeated ResourceHandleProto.DtypeAndShape handle_data = 7;
    bool is_ref = 16;
    FullTypeDef experimental_full_type = 17;
  }

  repeated ArgDef input_arg = 2;
  repeated ArgDef output_arg = 3;
  repeated string control_output = 20; //控制参数

  message AttrDef {   //op属性,构图时已经确定不变

    string name = 1;
    string type = 2;
    AttrValue default_value = 3;
    string description = 4;
    bool has_minimum = 5;
    int64 minimum = 6;

    AttrValue allowed_values = 7;
  }
  repeated AttrDef attr = 4; //属性

}

message OpDeprecation {
  int32 version = 1;

  string explanation = 2;
}

message OpList {  //一组op
  repeated OpDef op = 1;
}

OpDefBuilder来生成OP

Builder可以通过特定语法格式的字符串来添加 输入参数,输出参数等。添加完成后调用Finalize(OpRegistrationData* op_reg_data)生成了OpRegistrationData. OpRegistrationData有OpDef

tensorflow/core/framework/op_def_builder.h


// Builder class passed to the REGISTER_OP() macro.
class OpDefBuilder {
 public:
  explicit OpDefBuilder(std::string op_name);
  //定义属性
  OpDefBuilder& Attr(std::string spec);
  //定义输入输出
  OpDefBuilder& Input(std::string spec);
  OpDefBuilder& Output(std::string spec);

  OpRegistrationData op_reg_data_;
  std::vector<string> attrs_;
  std::vector<string> inputs_;
  std::vector<string> outputs_;
  std::vector<string> control_outputs_;
};

Op注册原理

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

REGISTER_OP宏,实际上定义了如下的OpDefBuilderWrapper的对象。

后续调用的.Input, .Output,等都是对此对象中的Input, Output的方法的调用。而Input里实现上转而调用了OpDefBuilder的Input。


namespace register_op {

class OpDefBuilderWrapper {
 public:
  explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
  //属性
  OpDefBuilderWrapper& Attr(std::string spec) {
    builder_.Attr(std::move(spec));
    return *this;
  }
  //输入
  OpDefBuilderWrapper& Input(std::string spec) {
    builder_.Input(std::move(spec));
    return *this;
  }
  //输出
  OpDefBuilderWrapper& Output(std::string spec) {
    builder_.Output(std::move(spec));
    return *this;
  }

  //下文中提到的InitOnStartupMarker 中调用了这个
  InitOnStartupMarker operator()();

 private:
  mutable ::tensorflow::OpDefBuilder builder_;
};

}  

#define REGISTER_OP_IMPL(ctr, name, is_system_op)                         \
  static ::tensorflow::InitOnStartupMarker const register_op##ctr         \
      TF_ATTRIBUTE_UNUSED =                                               \
          TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name)) \
          << ::tensorflow::register_op::OpDefBuilderWrapper(name)

#define REGISTER_OP(name)        \
  TF_ATTRIBUTE_ANNOTATE("tf:op") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false)


#define REGISTER_SYSTEM_OP(name)        \
  TF_ATTRIBUTE_ANNOTATE("tf:op")        \
  TF_ATTRIBUTE_ANNOTATE("tf:op:system") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, true)

} 

  • REGISTER_OP这个宏调用TF_NEW_ID_FOR_INIT. 会使用__COUNTER__宏生成唯一ID.
  • 调用REGISTER_OP_IMPLE时,参数ctr就是counter。
  • REGISTER_OP_IMPLE所有定义了一个static变量。变量类型是 tensorlfow::InitOnStartUpMarker。变量名是register_op##ctr,实际上就是register_op0, register_op1, ....
  • TF_INIT_ON_STARTUP_IF宏如果参数是false,则什么也不做,否则 调用后边的<< OpeDefBuilder。这个宏根相当于:!cond ? InitOnStartupMarker{} : (InitOnStartupMarker{} << f);   f就是::tensorflow::register_op::OpDefBuilderWrapper(name)。因为InitOnStartUpmarker重载了operator<<。
  • 在下图代码InitOpStartupMarker里调用了OpDefBuilderWrapper的Operator()方法。
struct InitOnStartupMarker {
  constexpr InitOnStartupMarker operator<<(InitOnStartupMarker) const {
    return *this;
  }

  template <typename T>
  constexpr InitOnStartupMarker operator<<(T&& v) const {
    return std::forward<T>(v)();  #相当于调用OpDefBuilderWrapper对像的operator()
  }
};


// 是否在启动时就注册
#define TF_INIT_ON_STARTUP_IF(cond)                \
  (::std::integral_constant<bool, !(cond)>::value) \
      ? ::tensorflow::InitOnStartupMarker{}        \
      : ::tensorflow::InitOnStartupMarker {}

真正注册在这里:通过builder获取全局注册中心,实际上是不台OP的构建器function保存下来,在需要的时候就可以通过它来new出新的OP对象了。

InitOnStartupMarker OpDefBuilderWrapper::operator()() {
  OpRegistry::Global()->Register(
      [builder =
           std::move(builder_)](OpRegistrationData* op_reg_data) -> Status {
        return builder.Finalize(op_reg_data);
      });
  return {};
}

// static,单例
OpRegistry* OpRegistry::Global() {
  static OpRegistry* global_op_registry = new OpRegistry;
  return global_op_registry;
}

void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
  mutex_lock lock(mu_);
  if (initialized_) {
    TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
  } else {
    deferred_.push_back(op_data_factory);
  }
}

typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;

最终把构建好的op,其实就是OpRegistrationData插入到map<op_name, OpRegistrationData*>  OpRegistry::registry_中。

总结

op声明就是构建OpRegistrationData,其中需要添加输入输出,属性等等参数。为此OpDefBuilder来方便注册,可以一步步添加输入输出,最后调用个Finalize来生成OpRegistrationData。OP声明需要注册到OpRegistry中,通过调用宏REGISTER_OP来生成全局静态变量OpDefBuilderWrapper,在静态变量初始化时会把构建好的OpRegistrationData添加到OpRegistry中。