tensorflow对图的优化包括MetaOptimizer、GraphOptimizationPass、GraphOptimizer,这里先看GraphOptimizationPass的框架实现
优化器注册
OptimizationPassRegistry
// Registry of graph-optimization passes, keyed first by Grouping (when in
// initialization the pass runs) and then by integer phase (order within a
// grouping). Passes are registered via REGISTER_OPTIMIZATION and executed
// through RunGrouping().
class OptimizationPassRegistry {
public:
// Groups of passes are run at different points in initialization.
enum Grouping { // the four stages at which passes can run
PRE_PLACEMENT, // after cost model assignment, before placement.
POST_PLACEMENT, // after placement.
POST_REWRITE_FOR_EXEC, // after re-write using feed/fetch endpoints.
POST_PARTITIONING, // after partitioning
};
// Add an optimization pass to the registry.
void Register(Grouping grouping, int phase,
std::unique_ptr<GraphOptimizationPass> pass); // inserts the pass into groups_
const std::map<Grouping, GraphOptimizationPasses>& groups() {
return groups_;
}
// Run all passes in grouping, ordered by phase, with the same
// options.
Status RunGrouping(Grouping grouping,
const GraphOptimizationPassOptions& options);
// Returns the process-wide singleton registry.
static OptimizationPassRegistry* Global();
// Prints registered optimization passes for debugging.
void LogGrouping(Grouping grouping, int vlog_level);
void LogAllGroupings(int vlog_level);
private:
// Outer key: one of the four Grouping stages above.
// Outer value: GraphOptimizationPasses, itself a map from
//   phase number (iterated in ascending order, std::map ordering) to
//   the collection of GraphOptimizationPass pointers for that phase.
std::map<Grouping, GraphOptimizationPasses> groups_; // all registered passes
};
GraphOptimizationPassOptions
作为优化器的输入,保存了图和session等的一些信息
// Input to GraphOptimizationPass::Run: bundles the graph to optimize together
// with session/device/function-library context. Pointer members are borrowed
// (not owned) unless noted otherwise.
struct GraphOptimizationPassOptions {
string session_handle; // filled by DirectSession during PRE_PLACEMENT
const SessionOptions* session_options = nullptr;
const CostModel* cost_model = nullptr;
FunctionLibraryDefinition* flib_def = nullptr; // Not owned.
const DeviceSet* device_set = nullptr; // Not owned.
// Pointer to the owning unique_ptr, so a pass can replace the graph wholesale.
std::unique_ptr<Graph>* graph = nullptr;
// Per-partition graphs, keyed by partition name; presumably populated only
// for POST_PARTITIONING passes — TODO confirm against the callers.
std::unordered_map<string, std::unique_ptr<Graph>>*
partition_graphs = nullptr;
// Indicator of whether or not the graph was derived from a function.
bool is_function_graph = false;
};
GraphOptimizationPass
优化器基类,优化器都是由它派生而来,添加新的优化器需要继承这个基类
这个类只声明了一个纯虚的Run方法,接收GraphOptimizationPassOptions类型的参数
// Abstract base class for all graph-optimization passes. New passes subclass
// this and implement Run(); registration happens via REGISTER_OPTIMIZATION.
class GraphOptimizationPass {
public:
virtual ~GraphOptimizationPass() {}
virtual Status Run(const GraphOptimizationPassOptions& options) = 0; // pure virtual
void set_name(const string& name) { name_ = name; }
string name() const { return name_; }
private:
// The name of the optimization pass, which is the same as the inherited
// class name.
string name_; // pass name, set at registration time
};
注册方式
同op注册一样,OptimizationPass也使用宏定义方式注册
在各具体的优化器文件中调用注册优化器
// __COUNTER__ expands to an integer starting at 0 that increments on every
// expansion within a translation unit; it is used here to make each generated
// registration variable name unique.
#define REGISTER_OPTIMIZATION(grouping, phase, optimization) \
REGISTER_OPTIMIZATION_UNIQ_HELPER(__COUNTER__, grouping, phase, optimization)
// Extra level of indirection so __COUNTER__ is expanded before token pasting.
#define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \
REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)
// Defines a file-local static whose constructor performs the registration;
// #optimization stringizes the class name to use as the pass name.
#define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) \
static ::tensorflow::optimization_registration::OptimizationPassRegistration \
register_optimization_##ctr( \
grouping, phase, \
::std::unique_ptr<::tensorflow::GraphOptimizationPass>( \
new optimization()), \
#optimization)
比如,注册一个grouping为PRE_PLACEMENT,phase为20,名字为OptimizationExample的优化器,那么实际生成的代码是:
// Example expansion of REGISTER_OPTIMIZATION(PRE_PLACEMENT, 20, OptimizationExample):
static ::tensorflow::optimization_registration::OptimizationPassRegistration
register_optimization_0(PRE_PLACEMENT, 20, ::std::unique_ptr<::tensorflow::GraphOptimizationPass>(new OptimizationExample()), "OptimizationExample")
实际是定义构造了一个OptimizationPassRegistration类型的变量,变量名为register_optimization_xx(xx为一个唯一不重复的数字),这个变量名本身并不会被使用,但定义静态变量的过程会调用OptimizationPassRegistration的构造函数,构造函数会把这个优化器注册到OptimizationPassRegistry单例对象的groups_中
OptimizationPassRegistration类的定义及构造函数为:
class OptimizationPassRegistration {
public:
OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping,
int phase,
std::unique_ptr<GraphOptimizationPass> pass,
string optimization_pass_name) {
pass->set_name(optimization_pass_name);
OptimizationPassRegistry::Global()->Register(grouping, phase,
std::move(pass));
}
};
会去调用OptimizationPassRegistry::Global()->Register将其注册到groups_中
最终在groups_[PRE_PLACEMENT][20](PRE_PLACEMENT枚举值为0)中增加一个元素,即继承自GraphOptimizationPass的OptimizationExample类型对象指针
优化器执行
// Runs every pass registered under `grouping`, in ascending phase order, all
// with the same `options`. Returns the first non-OK status (aborting the
// remaining passes), or OK. Records per-pass and whole-group wall time via
// metrics, and dumps the graph(s) after each pass when VLOG(1) is enabled.
Status OptimizationPassRegistry::RunGrouping(
    Grouping grouping, const GraphOptimizationPassOptions& options) {
  auto group_it = groups_.find(grouping);
  if (group_it == groups_.end()) {
    // Nothing registered for this grouping: a no-op.
    return Status::OK();
  }
  const uint64 group_start_us = Env::Default()->NowMicros();
  // Phases iterate in ascending order (std::map ordering).
  for (auto& phase_entry : group_it->second) {
    VLOG(1) << "Running optimization phase " << phase_entry.first;
    for (auto& pass : phase_entry.second) {
      VLOG(1) << "Running optimization pass: " << pass->name();
      const uint64 pass_start_us = Env::Default()->NowMicros();
      Status status = pass->Run(options);  // run the optimization itself
      const uint64 pass_end_us = Env::Default()->NowMicros();
      // Record per-pass wall time before checking for failure.
      metrics::UpdateGraphOptimizationPassTime(pass->name(),
                                               pass_end_us - pass_start_us);
      if (!status.ok()) return status;
      if (VLOG_IS_ON(1)) {
        // Dump the post-pass graph(s) to files for debugging.
        if (options.graph) {
          DumpGraphToFile(
              strings::StrCat("after_group_", grouping, "_phase_",
                              phase_entry.first, "_", pass->name(), "_",
                              reinterpret_cast<uintptr_t>(
                                  (*options.graph).get())),
              **options.graph, options.flib_def);
        }
        if (options.partition_graphs) {
          for (auto& partition : *options.partition_graphs) {
            DumpGraphToFile(
                strings::StrCat(
                    "after_group_", grouping, "_phase_", phase_entry.first,
                    "_", pass->name(), "_partition_", partition.first, "_",
                    reinterpret_cast<uintptr_t>(partition.second.get())),
                *partition.second, options.flib_def);
          }
        }
      }
    }
  }
  const uint64 group_end_us = Env::Default()->NowMicros();
  // "*" aggregates the whole grouping's elapsed time.
  metrics::UpdateGraphOptimizationPassTime("*", group_end_us - group_start_us);
  return Status::OK();
}
优化器举例
tensorflow/core/common_runtime/accumulate_n_optimizer.cc
对op为"AccumulateNV2"的节点进行的一个优化
AccumulateNV2这个op所做的事情跟add_n一样,就是累加多个tensors,但是不需要等所有input都ready之后才开始累加, 文档说明:tensorflow AccumulateNV2
待完善......