tensorflow源码分析之reader

301 阅读3分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

 什么是Reader

Reader即数据读取器,负责把外部的训练数据读入TensorFlow计算图。借助reader_op,可以从本地文件系统、HDFS、数据库等数据源读取数据。注意:Reader本身不是OP,而是一种resource(资源),供reader_op使用。

Reader接口(注意:Reader不是OP,reader_op才是OP)


class QueueInterface;
class ReaderInterface;


// All descendants of this class must be thread-safe.
//
// Usage model: each reading task is encoded as a string (for example, a file
// name) and placed on a work queue. The reader takes these work units from
// the queue one at a time, reads them, and emits key/value strings as output.
class ReaderInterface : public ResourceBase {
 public:
  // Read a single record into *key / *value.  May get more work from
  // *queue if the current work is complete.  Sets the status on
  // *context with an OutOfRange Status if the current work is
  // complete and the queue is done (closed and empty).
  // This method may block.
  virtual void Read(QueueInterface* queue, tstring* key, tstring* value,
                    OpKernelContext* context) = 0;

  // Read up to num_records key/value pairs into *keys / *value and return the
  // number of records actually read, which may be fewer than num_records.
  // NOTE(review): presumably reports end-of-input on *context the same way
  // Read() does — confirm against the implementations.
  virtual int64_t ReadUpTo(const int64_t num_records, QueueInterface* queue,
                           std::vector<tstring>* keys,
                           std::vector<tstring>* value,
                           OpKernelContext* context) = 0;

  // Restore this reader to its newly-constructed state.
  virtual Status Reset() = 0;

  // Accessors
  // Number of key/value records produced so far.
  virtual int64_t NumRecordsProduced() = 0;
  // Number of work units completed. Each work unit is an item taken from
  // the work queue.
  virtual int64_t NumWorkUnitsCompleted() = 0;

  // Serialize the reader's state (e.g. how far it has read) into *state, so
  // the caller can store it externally and later hand it to RestoreState().
  virtual Status SerializeState(tstring* state) = 0;
  // Restore a state previously produced by SerializeState().
  virtual Status RestoreState(const tstring& state) = 0;

  // Fixed debug description used unless a descendant overrides it.
  string DebugString() const override { return "a reader"; }

 protected:
  // Protected so readers are not deleted directly through this interface.
  // NOTE(review): lifetime is presumably managed by ResourceBase's reference
  // counting — confirm.
  virtual ~ReaderInterface() {}
};

Reader的状态可以保存,之后再恢复(对应SerializeState/RestoreState)

tensorflow/core/framework/reader_base.proto

// Serialized ReaderBase bookkeeping. The fields share their names with
// ReaderBase's members (work_started_, work_finished_, num_records_produced_,
// work_). NOTE(review): presumably filled by ReaderBase::SaveBaseState() and
// consumed by RestoreBaseState() — confirm.
message ReaderBaseState {
  // Number of work units started.
  int64 work_started = 1;
  // Number of work units finished.
  int64 work_finished = 2;
  // Number of key/value records produced so far.
  int64 num_records_produced = 3;
  // The work unit currently being processed (e.g. a file name).
  bytes current_work = 4;
}

ReaderBase:ReaderInterface的默认实现


// Default implementation of ReaderInterface.
class ReaderBase : public ReaderInterface {
 public:
  explicit ReaderBase(const string& name);
  // Methods whose names end in "Locked" are called with ReaderBase's mutex
  // (mu_) already held.
 
  // Descendants implement the following -------------------------------------

  // Produce the next key/value pair. This runs with the mutex held, so all
  // reads on one reader are serialized.
  //  a) If a record (key, value) was read successfully, set *produced = true
  //     and fill in *key and *value.
  //  b) If there are no more records to read, set *at_end = true.
  //  c) If a record was read and it is already known that no more records
  //     follow, *produced and *at_end may both be set to true.
  //  d) If an error occurs while reading, return a non-OK Status. ReadLocked
  //     may be called again later to retry.
  
  virtual Status ReadLocked(tstring* key, tstring* value, bool* produced,
                            bool* at_end) = 0;

  // Produce up to num_records key/value pairs, following the same protocol
  // as ReadLocked(). NOTE(review): presumably the number actually produced
  // is reported through *num_read — confirm against the default
  // implementation.
  virtual Status ReadUpToLocked(int64_t num_records, std::vector<tstring>* keys,
                                std::vector<tstring>* values, int64_t* num_read,
                                bool* at_end);

  // Hooks called when a work unit starts and when it finishes. The defaults
  // do nothing and return OK.
  virtual Status OnWorkStartedLocked() { return Status::OK(); }
  virtual Status OnWorkFinishedLocked() { return Status::OK(); }

  // Called to reset the Reader to a newly constructed state.
  virtual Status ResetLocked();

  // Default implementation generates an Unimplemented error.
  // See the protected helper methods below.
  virtual Status SerializeStateLocked(tstring* state);
  virtual Status RestoreStateLocked(const tstring& state);

  // Accessors ----------------------------------------------------------------

  // True while a work unit has been started but not yet finished; always
  // true during a call to ReadLocked().
  bool work_in_progress() const { return work_finished_ < work_started_; }

  // The current work unit (e.g. a file name taken from the work queue).
  const tstring& current_work() const { return work_; }

  // What was passed to the constructor.
  const string& name() const { return name_; }

  // Produce the key name (from current_work and the actual key).
  tstring KeyName(const tstring& key) const;

 protected:
  // For descendants wishing to implement serialize & restore state.

  // Writes ReaderBase state to *state.
  void SaveBaseState(ReaderBaseState* state) const;

  // Restores ReaderBase state from state. Assumes state was filled
  // using SaveBaseState() above.
  Status RestoreBaseState(const ReaderBaseState& state);

 private:
  // Fetch the next work unit from the queue; in practice this is usually the
  // next data file to read.
  virtual string GetNextWorkLocked(QueueInterface* queue,
                                   OpKernelContext* context) const;

  // ReaderInterface overrides. They are private, so callers presumably reach
  // them through a ReaderInterface pointer. NOTE(review): the implementations
  // are not shown here; they are expected to take mu_ and delegate to the
  // *Locked methods above — confirm.

  // Thread-safe read of a single key/value pair.
  void Read(QueueInterface* queue, tstring* key, tstring* value,
            OpKernelContext* context) override;

  // Read up to num_records records before returning.
  int64_t ReadUpTo(const int64_t num_records, QueueInterface* queue,
                   std::vector<tstring>* keys, std::vector<tstring>* value,
                   OpKernelContext* context) override;

  Status Reset() override;
  int64_t NumRecordsProduced() override;
  int64_t NumWorkUnitsCompleted() override;
  Status SerializeState(tstring* state) override;
  Status RestoreState(const tstring& state) override;

  // Held whenever a *Locked method runs; serializes all reader calls.
  mutable mutex mu_;
  // Name passed to the constructor.
  const string name_;
  // Counts of work units started and finished; see work_in_progress().
  int64_t work_started_ = 0;
  int64_t work_finished_ = 0;
  // Number of key/value records produced so far.
  int64_t num_records_produced_ = 0;
  // The current work unit; see current_work().
  tstring work_;
};

ReaderOpKernel

tensorflow/core/framework/reader_op_kernel.h

这是reader op,属于资源型OP:它只负责创建(Create)Reader资源,本身并不产生数据输出。创建好的Reader会被放入ResourceMgr,OP最终输出一个ResourceHandle;真正使用这个ResourceHandle读取数据的代码并不在ReaderOpKernel里。这一点从其父类ResourceOpKernel::Compute的实现就能看出。


// Implementation for ops providing a Reader.
//
// This is a resource op. Compute() delegates to
// ResourceOpKernel<ReaderInterface>::Compute, which obtains the reader
// resource; when one must be created, CreateResource() builds it with the
// factory installed via SetReaderFactory(). This kernel does not itself read
// any records.
class ReaderOpKernel : public ResourceOpKernel<ReaderInterface> {
 public:
  using ResourceOpKernel::ResourceOpKernel;

  // Installs the factory CreateResource() uses to build the reader.
  // Subclasses may call this before Compute() is invoked. Must not be
  // called once the resource has been created (DCHECKed below).
  void SetReaderFactory(std::function<ReaderInterface*()> factory)
      TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    DCHECK(resource_ == nullptr);
    factory_ = factory;
  }

  void Compute(OpKernelContext* context) override {
    if (!IsCancellable()) {
      ResourceOpKernel<ReaderInterface>::Compute(context);
      return;
    }

    CancellationManager* cm = context->cancellation_manager();
    if (cm == nullptr) {
      // No manager to register with: cancellation cannot be requested, so
      // simply run the base-class Compute().
      ResourceOpKernel<ReaderInterface>::Compute(context);
      return;
    }

    // Install cancellation for the duration of this Compute() call only.
    CancellationToken token = cm->get_cancellation_token();
    const bool already_cancelled =
        !cm->RegisterCallback(token, [this]() { this->Cancel(); });
    if (already_cancelled) {
      context->SetStatus(errors::Cancelled("read operation was cancelled"));
      return;
    }

    ResourceOpKernel<ReaderInterface>::Compute(context);

    // The callback captures `this` and is only meaningful while Compute()
    // runs. Deregister it so a later cancellation of the same manager cannot
    // call Cancel() after this op has finished. mu_ is not held here, so a
    // concurrently running Cancel() can complete if this call has to wait.
    cm->DeregisterCallback(token);
  }

 private:
  // Subclasses override these to make the op cancellable.
  virtual bool IsCancellable() const { return false; }
  virtual void Cancel() {}

  // Builds the reader with the installed factory. The factory may be set by
  // a subclass, so the concrete reader type can vary.
  Status CreateResource(ReaderInterface** reader)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
    if (!factory_) {
      // Calling an empty std::function would throw bad_function_call (or
      // abort without exceptions). Report a status instead. Keep *reader
      // null so the caller does not try to release a bogus resource.
      *reader = nullptr;
      return errors::FailedPrecondition(
          "Reader factory is not set; call SetReaderFactory() before "
          "creating the reader");
    }
    *reader = factory_();
    if (*reader == nullptr) {
      return errors::ResourceExhausted("Failed to allocate reader");
    }
    // Release the factory and anything it captured once it has produced
    // the reader.
    factory_ = nullptr;
    return Status::OK();
  }

  std::function<ReaderInterface*()> factory_ TF_GUARDED_BY(mu_);
};