operator(op)是tensorflow扩展功能的的方式。OP分为声明和定义。声明叫op,实现叫kernel.一个声明可以有多个实现。或者说在不同设备上的不同实现。OP需要注册。
时刻注意,OP只是一个声明。如同C++的函数声明。并不涉及这些OP如何实现。比如可以声明一个OP叫Add,其功能是可以做两个数的加法int Add(int a, int b); 而这个声明用一个proto message表示就是message OpDef。而图就是多个OP的输入输出首尾相接组成的有向无环图,这个图实际上表示了函数的调用关系。
OP注册中心接口
只提供了根据名字查找OP的接口。tensorflow/core/framework/op.h
class OpRegistryInterface {
public:
virtual ~OpRegistryInterface();
virtual Status LookUp(const std::string& op_type_name,
const OpRegistrationData** op_reg_data) const = 0;
Status LookUpOpDef(const std::string& op_type_name,
const OpDef** op_def) const;
};
//实际的一个实现
class OpRegistry : public OpRegistryInterface {
public:
typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
OpRegistry();
~OpRegistry() override;
void Register(const OpRegistrationDataFactory& op_data_factory);
Status LookUp(const std::string& op_type_name,
const OpRegistrationData** op_reg_data) const override;
mutable mutex mu_;
// Functions in deferred_ may only be called with mu_ held.
mutable std::vector<OpRegistrationDataFactory> deferred_ TF_GUARDED_BY(mu_);
// Values are owned.
mutable std::unordered_map<string, const OpRegistrationData*> registry_
TF_GUARDED_BY(mu_); //op就是注册在这里了
mutable bool initialized_ TF_GUARDED_BY(mu_);
// Registry watcher.
mutable Watcher watcher_ TF_GUARDED_BY(mu_);
std::function<Status(const OpRegistryInterface&)> op_registry_validator_;
};
看这几个接口很简单,但是其参数OpDef, OpRegistrationData很复杂。
OpDef
一个op有多个输入参数,和多个输入属性,还有多个输出参数,多个控制输出。它们都是Tensor。
输入属性的值在构图时已经确定不变了。而输入参数是执行图时变化数据。
编辑
class OpDef 是定义在proto中的。tensorflow/core/framework/op_def.proto
这个proto就声明了个OP.实际上就是把输入输出参数,OP名字等等元信息保存下来。
message OpDef {
string name = 1; // op名字
message ArgDef { // op输入输出参数
string name = 1;
string description = 2;
DataType type = 3;
string type_attr = 4; // if specified, attr must have type "type"
string number_attr = 5; // if specified, attr must have type "int"
string type_list_attr = 6;
repeated ResourceHandleProto.DtypeAndShape handle_data = 7;
bool is_ref = 16;
FullTypeDef experimental_full_type = 17;
}
repeated ArgDef input_arg = 2;
repeated ArgDef output_arg = 3;
repeated string control_output = 20; //控制参数
message AttrDef { //op属性,构图时已经确定不变
string name = 1;
string type = 2;
AttrValue default_value = 3;
string description = 4;
bool has_minimum = 5;
int64 minimum = 6;
AttrValue allowed_values = 7;
}
repeated AttrDef attr = 4; //属性
}
message OpDeprecation {
int32 version = 1;
string explanation = 2;
}
message OpList { //一组op
repeated OpDef op = 1;
}
OpDefBuilder来生成OP
Builder可以通过特定语法格式的字符串来添加 输入参数,输出参数等。添加完成后调用Finalize(OpRegistrationData* op_reg_data)生成了OpRegistrationData. OpRegistrationData有OpDef
tensorflow/core/framework/op_def_builder.h
// Builder class passed to the REGISTER_OP() macro.
class OpDefBuilder {
public:
explicit OpDefBuilder(std::string op_name);
//定义属性
OpDefBuilder& Attr(std::string spec);
//定义输入输出
OpDefBuilder& Input(std::string spec);
OpDefBuilder& Output(std::string spec);
OpRegistrationData op_reg_data_;
std::vector<string> attrs_;
std::vector<string> inputs_;
std::vector<string> outputs_;
std::vector<string> control_outputs_;
};
Op注册原理
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
REGISTER_OP宏,实际上定义了如下的OpDefBuilderWrapper的对象。
后续调用的.Input, .Output,等都是对此对象中的Input, Output的方法的调用。而Input里实现上转而调用了OpDefBuilder的Input。
namespace register_op {
class OpDefBuilderWrapper {
public:
explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
//属性
OpDefBuilderWrapper& Attr(std::string spec) {
builder_.Attr(std::move(spec));
return *this;
}
//输入
OpDefBuilderWrapper& Input(std::string spec) {
builder_.Input(std::move(spec));
return *this;
}
//输出
OpDefBuilderWrapper& Output(std::string spec) {
builder_.Output(std::move(spec));
return *this;
}
//下文中提到的InitOnStartupMarker 中调用了这个
InitOnStartupMarker operator()();
private:
mutable ::tensorflow::OpDefBuilder builder_;
};
}
#define REGISTER_OP_IMPL(ctr, name, is_system_op) \
static ::tensorflow::InitOnStartupMarker const register_op##ctr \
TF_ATTRIBUTE_UNUSED = \
TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name)) \
<< ::tensorflow::register_op::OpDefBuilderWrapper(name)
#define REGISTER_OP(name) \
TF_ATTRIBUTE_ANNOTATE("tf:op") \
TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false)
#define REGISTER_SYSTEM_OP(name) \
TF_ATTRIBUTE_ANNOTATE("tf:op") \
TF_ATTRIBUTE_ANNOTATE("tf:op:system") \
TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, true)
}
- REGISTER_OP这个宏调用TF_NEW_ID_FOR_INIT. 会使用__COUNTER__宏生成唯一ID.
- 调用REGISTER_OP_IMPLE时,参数ctr就是counter。
- REGISTER_OP_IMPLE所有定义了一个static变量。变量类型是 tensorlfow::InitOnStartUpMarker。变量名是register_op##ctr,实际上就是register_op0, register_op1, ....
- TF_INIT_ON_STARTUP_IF宏如果参数是false,则什么也不做,否则 调用后边的<< OpeDefBuilder。这个宏根相当于:!cond ? InitOnStartupMarker{} : (InitOnStartupMarker{} << f); f就是::tensorflow::register_op::OpDefBuilderWrapper(name)。因为InitOnStartUpmarker重载了operator<<。
- 在下图代码InitOpStartupMarker里调用了OpDefBuilderWrapper的Operator()方法。
struct InitOnStartupMarker {
constexpr InitOnStartupMarker operator<<(InitOnStartupMarker) const {
return *this;
}
template <typename T>
constexpr InitOnStartupMarker operator<<(T&& v) const {
return std::forward<T>(v)(); #相当于调用OpDefBuilderWrapper对像的operator()
}
};
// 是否在启动时就注册
#define TF_INIT_ON_STARTUP_IF(cond) \
(::std::integral_constant<bool, !(cond)>::value) \
? ::tensorflow::InitOnStartupMarker{} \
: ::tensorflow::InitOnStartupMarker {}
真正注册在这里:通过builder获取全局注册中心,实际上是不台OP的构建器function保存下来,在需要的时候就可以通过它来new出新的OP对象了。
InitOnStartupMarker OpDefBuilderWrapper::operator()() {
OpRegistry::Global()->Register(
[builder =
std::move(builder_)](OpRegistrationData* op_reg_data) -> Status {
return builder.Finalize(op_reg_data);
});
return {};
}
// static,单例
OpRegistry* OpRegistry::Global() {
static OpRegistry* global_op_registry = new OpRegistry;
return global_op_registry;
}
void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
mutex_lock lock(mu_);
if (initialized_) {
TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
} else {
deferred_.push_back(op_data_factory);
}
}
typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
最终把构建好的op,其实就是OpRegistrationData插入到map<op_name, OpRegistrationData*> OpRegistry::registry_中。
总结
op声明就是构建OpRegistrationData,其中需要添加输入输出,属性等等参数。为此OpDefBuilder来方便注册,可以一步步添加输入输出,最后调用个Finalize来生成OpRegistrationData。OP声明需要注册到OpRegistry中,通过调用宏REGISTER_OP来生成全局静态变量OpDefBuilderWrapper,在静态变量初始化时会把构建好的OpRegistrationData添加到OpRegistry中。