[tensorflow]图优化之GraphOptimizationPass

563 阅读3分钟

tensorflow对图的优化包括MetaOptimizer、GraphOptimizationPass、GraphOptimizer,这里先看GraphOptimizationPass的框架实现

优化器注册

OptimizationPassRegistry

// Registry that owns every GraphOptimizationPass, keyed first by Grouping
// (the point during graph setup at which the pass runs) and then by an
// integer phase (ascending execution order within a grouping).
class OptimizationPassRegistry {
 public:
  // Groups of passes are run at different points in initialization.
  enum Grouping {  // The four stages at which pass groups run.
    PRE_PLACEMENT,          // after cost model assignment, before placement.
    POST_PLACEMENT,         // after placement.
    POST_REWRITE_FOR_EXEC,  // after re-write using feed/fetch endpoints.
    POST_PARTITIONING,      // after partitioning
  };

  // Add an optimization pass to the registry.
  void Register(Grouping grouping, int phase,
                std::unique_ptr<GraphOptimizationPass> pass); // Inserts the pass into groups_.

  const std::map<Grouping, GraphOptimizationPasses>& groups() {
    return groups_;
  }
  // Run all passes in grouping, ordered by phase, with the same
  // options.
  Status RunGrouping(Grouping grouping,
                     const GraphOptimizationPassOptions& options);

  // Returns the process-wide singleton registry.
  static OptimizationPassRegistry* Global();

  // Prints registered optimization passes for debugging.
  void LogGrouping(Grouping grouping, int vlog_level);
  void LogAllGroupings(int vlog_level);

 private:
  // key:   one of the four Grouping stages above.
  // value: GraphOptimizationPasses is a map<int, vector of pass pointers>:
  //        key:   phase number, iterated in ascending order
  //        value: the GraphOptimizationPass pointers registered for that phase
  std::map<Grouping, GraphOptimizationPasses> groups_; // Grouping -> registered passes.
};

GraphOptimizationPassOptions

作为优化器的输入,保存了图和session等的一些信息

// Input to every GraphOptimizationPass::Run call: bundles the graph(s) to be
// optimized together with session/device context. All raw pointers are
// borrowed, not owned.
struct GraphOptimizationPassOptions {
  string session_handle; // Filled in by DirectSession at the PRE_PLACEMENT stage.
  const SessionOptions* session_options = nullptr;
  const CostModel* cost_model = nullptr;

  FunctionLibraryDefinition* flib_def = nullptr;  // Not owned.
  const DeviceSet* device_set = nullptr;  // Not owned.

  // The whole graph (pre-partitioning); passes may replace *graph.
  std::unique_ptr<Graph>* graph = nullptr;
  // Per-device partitioned graphs; only set at POST_PARTITIONING.
  std::unordered_map<string, std::unique_ptr<Graph>>* 
          partition_graphs = nullptr;

  // Indicator of whether or not the graph was derived from a function.
  bool is_function_graph = false;
};

GraphOptimizationPass

优化器基类,优化器都是由它派生而来,添加新的优化器需要继承这个基类

这个类只声明了一个纯虚的Run方法,接收GraphOptimizationPassOptions类型的参数

// Abstract base class for all optimization passes; new passes derive from it
// and override Run.
class GraphOptimizationPass {
 public:
  virtual ~GraphOptimizationPass() {}
  virtual Status Run(const GraphOptimizationPassOptions& options) = 0; // pure virtual
  void set_name(const string& name) { name_ = name; }
  string name() const { return name_; }

 private:
  // The name of the optimization pass, which is the same as the inherited
  // class name.
  string name_;  // Name of the pass.
};

注册方式

同op注册一样,OptimizationPass也使用宏定义方式注册

在各具体的优化器文件中调用注册优化器

// __COUNTER__ expands to an integer literal starting at 0 that increments on
// every expansion, which guarantees a unique variable name per registration.
#define REGISTER_OPTIMIZATION(grouping, phase, optimization) \
  REGISTER_OPTIMIZATION_UNIQ_HELPER(__COUNTER__, grouping, phase, optimization)

// Extra indirection so that __COUNTER__ is expanded before token pasting.
#define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \
  REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)

// Defines a uniquely named static OptimizationPassRegistration whose
// constructor registers the pass; #optimization becomes the pass name.
#define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)         \
  static ::tensorflow::optimization_registration::OptimizationPassRegistration \
      register_optimization_##ctr(                                             \
          grouping, phase,                                                     \
          ::std::unique_ptr<::tensorflow::GraphOptimizationPass>(              \
              new optimization()),                                             \
          #optimization)

比如,注册一个grouping为PRE_PLACEMENT,phase为20,名字为OptimizationExample的优化器,那么实际生成的代码是:

// Expansion of REGISTER_OPTIMIZATION(PRE_PLACEMENT, 20, OptimizationExample),
// assuming __COUNTER__ was 0 at this point:
static ::tensorflow::optimization_registration::OptimizationPassRegistration
register_optimization_0(PRE_PLACEMENT, 20, ::std::unique_ptr<::tensorflow::GraphOptimizationPass>(new OptimizationExample()), "OptimizationExample")

实际是定义构造了一个OptimizationPassRegistration类型的变量,变量名为register_optimization_xx(xx为一个唯一不重复的数字),这个变量名本身没有看到在哪使用,但是这个过程会调用OptimizationPassRegistration的构造函数,构造函数会注册这个优化器到OptimizationPassRegistry单例对象的groups_

OptimizationPassRegistration类的定义及构造函数为:

// Helper whose constructor performs the registration: instantiating a static
// object of this type (as the REGISTER_OPTIMIZATION macro does) names the
// pass and hands ownership to the global OptimizationPassRegistry.
class OptimizationPassRegistration {
 public:
  OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping,
                               int phase,
                               std::unique_ptr<GraphOptimizationPass> pass,
                               string optimization_pass_name) {
    pass->set_name(optimization_pass_name);
    OptimizationPassRegistry::Global()->Register(grouping, phase,
                                                 std::move(pass));
  }
};

会去调用OptimizationPassRegistry::Global()->Register将其注册到groups_中

最终在groups_[PRE_PLACEMENT][20](PRE_PLACEMENT枚举值为0)的vector中增加一个元素:一个指向OptimizationExample类型对象(继承自GraphOptimizationPass)的指针

优化器执行

// Runs every pass registered under `grouping`, iterating phases in ascending
// order and the passes within each phase in registration order. Returns the
// first non-OK status (aborting the grouping), otherwise OK. Records per-pass
// and whole-grouping wall time via the metrics module.
Status OptimizationPassRegistry::RunGrouping(
    Grouping grouping, const GraphOptimizationPassOptions& options) {
  const auto it = groups_.find(grouping);
  if (it == groups_.end()) return Status::OK();  // Nothing registered here.

  Env* const env = Env::Default();
  const uint64 start_us = env->NowMicros();
  for (auto& phase_entry : it->second) {  // std::map => ascending phase order.
    VLOG(1) << "Running optimization phase " << phase_entry.first;
    for (auto& pass : phase_entry.second) {
      VLOG(1) << "Running optimization pass: " << pass->name();
      const uint64 pass_start_us = env->NowMicros();
      const Status s = pass->Run(options);  // Execute the pass.
      const uint64 pass_end_us = env->NowMicros();
      // Record this pass's wall time, keyed by pass name.
      metrics::UpdateGraphOptimizationPassTime(pass->name(),
                                               pass_end_us - pass_start_us);
      if (!s.ok()) return s;  // First failure aborts the whole grouping.
      if (VLOG_IS_ON(1)) {
        // Dump the graph(s) after each pass for offline inspection.
        if (options.graph) {
          DumpGraphToFile(strings::StrCat("after_group_", grouping, "_phase_",
                                          phase_entry.first, "_", pass->name(),
                                          "_",
                                          reinterpret_cast<uintptr_t>(
                                              (*options.graph).get())),
                          **options.graph, options.flib_def);
        }
        if (options.partition_graphs) {
          for (auto& part : *options.partition_graphs) {
            DumpGraphToFile(
                strings::StrCat(
                    "after_group_", grouping, "_phase_", phase_entry.first, "_",
                    pass->name(), "_partition_", part.first, "_",
                    reinterpret_cast<uintptr_t>(part.second.get())),
                *part.second, options.flib_def);
          }
        }
      }
    }
  }
  const uint64 end_us = env->NowMicros();
  // "*" aggregates the time of the entire grouping.
  metrics::UpdateGraphOptimizationPassTime("*", end_us - start_us);
  return Status::OK();
}

优化器举例

tensorflow/core/common_runtime/accumulate_n_optimizer.cc

对op为"AccumulateNV2"的节点进行的一个优化

AccumulateNV2这个op所做的事情跟add_n一样,就是累加多个tensors,但是不需要等所有input都ready之后才开始累加, 文档说明:tensorflow AccumulateNV2

待完善......