本文已参与「新人创作礼」活动,一起开启掘金创作之路。
什么是Reader
Reader 即数据读取者,负责把外部的训练数据读取到 TensorFlow 图中。借助 reader,reader_op 可以从外部文件系统、HDFS、数据库等数据源读取数据。需要注意:reader 本身是一种 resource(资源),而不是 OP;它是供 reader_op 使用的资源。
reader接口 Reader不是OP, reader_op才是OP
class QueueInterface;
class ReaderInterface;
// Abstract interface for a Reader, exposed as a resource (it derives from
// ResourceBase).
//
// All descendants of this class must be thread-safe.
//
// To use a reader, each unit of reading work is encoded as a string (for
// example a file name) and placed on a work queue. The reader then takes work
// items off the queue one at a time, reads them, and produces key/value
// strings as its output.
class ReaderInterface : public ResourceBase {
public:
// Read a single record into *key / *value. May get more work from
// *queue if the current work is complete. Sets the status on
// *context with an OutOfRange Status if the current work is
// complete and the queue is done (closed and empty).
// This method may block.
virtual void Read(QueueInterface* queue, tstring* key, tstring* value,
OpKernelContext* context) = 0;
// Reads num_records key/value pairs into *keys / *value, taking work from
// *queue as needed.
// Returns the number of records that were actually read.
virtual int64_t ReadUpTo(const int64_t num_records, QueueInterface* queue,
std::vector<tstring>* keys,
std::vector<tstring>* value,
OpKernelContext* context) = 0;
// Restore this reader to its newly-constructed state.
virtual Status Reset() = 0;
// Accessors
// Number of key/value records that have been read so far.
virtual int64_t NumRecordsProduced() = 0;
// Number of work units completed so far, i.e. how many work items have been
// taken from the queue and finished.
virtual int64_t NumWorkUnitsCompleted() = 0;
// Serializes the reader's state (for example, how far it has read) into
// *state. The caller stores the string and can later hand it back to
// RestoreState() to resume.
virtual Status SerializeState(tstring* state) = 0;
// Restores the reader from a string previously produced by SerializeState().
virtual Status RestoreState(const tstring& state) = 0;
// Fixed, human-readable description of this resource.
string DebugString() const override { return "a reader"; }
protected:
// Protected destructor: a reader cannot be deleted directly through a
// ReaderInterface pointer from outside the hierarchy.
// NOTE(review): presumably lifetime is managed via ResourceBase - confirm.
virtual ~ReaderInterface() {}
};
Reader状态可保存
tensorflow/core/framework/reader_base.proto
// Serializable snapshot of a reader's base bookkeeping state.
// NOTE(review): the field names match the ReaderBase members work_started_,
// work_finished_, num_records_produced_ and work_; presumably this message is
// what SaveBaseState() fills and RestoreBaseState() consumes - confirm.
message ReaderBaseState {
// Count of work units that have been started.
int64 work_started = 1;
// Count of work units that have been finished.
int64 work_finished = 2;
// Count of records produced so far.
int64 num_records_produced = 3;
// The current work item, stored as raw bytes.
// NOTE(review): presumably the work string (e.g. a file name) - confirm.
bytes current_work = 4;
}
ReaderBase: reader的默认实现
// Default implementation of ReaderInterface.
//
// Subclasses implement the *Locked() hooks below; the ReaderInterface entry
// points are overridden privately by this class.
class ReaderBase : public ReaderInterface {
public:
// `name` is stored and later returned by name().
explicit ReaderBase(const string& name);
// Methods whose names end in "Locked" are called with ReaderBase's mutex
// already held.
//
// Subclasses implement the following ---------------------------------------
//
// ReadLocked produces a single key/value record. Because it runs under the
// lock, all reads are serialized.
//  a) If a record was read successfully, set *produced = true and fill in
//     *key and *value.
//  b) If there are no more records to read, set *at_end = true.
//  c) If a record was read successfully and it is already known that no
//     further records remain, both *produced and *at_end may be set to true.
//  d) If an error occurs while reading, return a non-OK Status. The read may
//     be attempted again.
virtual Status ReadLocked(tstring* key, tstring* value, bool* produced,
bool* at_end) = 0;
// Batch counterpart of ReadLocked: produces num_records key/value records
// into *keys / *values.
// NOTE(review): presumably *num_read receives the number of records actually
// produced and *at_end has the same meaning as in ReadLocked - confirm
// against the implementation.
virtual Status ReadUpToLocked(int64_t num_records, std::vector<tstring>* keys,
std::vector<tstring>* values, int64_t* num_read,
bool* at_end);
// Called when a unit of work starts / finishes. The defaults do nothing and
// return OK.
virtual Status OnWorkStartedLocked() { return Status::OK(); }
virtual Status OnWorkFinishedLocked() { return Status::OK(); }
// Called to reset the Reader to a newly constructed state.
virtual Status ResetLocked();
// Default implementation generates an Unimplemented error.
// See the protected helper methods below.
virtual Status SerializeStateLocked(tstring* state);
virtual Status RestoreStateLocked(const tstring& state);
// Accessors ----------------------------------------------------------------
// Always true during a call to ReadLocked().
// True while more work units have been started than finished.
bool work_in_progress() const { return work_finished_ < work_started_; }
// The current work item.
const tstring& current_work() const { return work_; }
// What was passed to the constructor.
const string& name() const { return name_; }
// Produce the key name (from current_work and the actual key).
tstring KeyName(const tstring& key) const;
protected:
// For descendants wishing to implement serialize & restore state.
// Writes ReaderBase state to *state.
void SaveBaseState(ReaderBaseState* state) const;
// Restores ReaderBase state from state. Assumes state was filled
// using SaveBaseState() above.
Status RestoreBaseState(const ReaderBaseState& state);
private:
// Fetches the next unit of work from *queue; in practice this is the next
// data file to read.
virtual string GetNextWorkLocked(QueueInterface* queue,
OpKernelContext* context) const;
// Thread-safe read of a single key/value record (ReaderInterface entry
// point).
void Read(QueueInterface* queue, tstring* key, tstring* value,
OpKernelContext* context) override;
// Reads num_records records before returning; per the interface, the return
// value is the number of records actually read.
int64_t ReadUpTo(const int64_t num_records, QueueInterface* queue,
std::vector<tstring>* keys, std::vector<tstring>* value,
OpKernelContext* context) override;
// Remaining ReaderInterface entry points.
Status Reset() override;
int64_t NumRecordsProduced() override;
int64_t NumWorkUnitsCompleted() override;
Status SerializeState(tstring* state) override;
Status RestoreState(const tstring& state) override;
// Mutex held while the *Locked() methods run; mutable so that it can also be
// taken from const member functions.
mutable mutex mu_;
// Name supplied at construction.
const string name_;
// Bookkeeping counters; work_in_progress() compares the first two.
int64_t work_started_ = 0;
int64_t work_finished_ = 0;
int64_t num_records_produced_ = 0;
// The current work item, as returned by current_work().
tstring work_;
};
ReaderOpKernel
tensorflow/core/framework/reader_op_kernel.h
这是 reader op 的实现,它是一种资源 OP:只负责创建(Create)Reader 资源,本身并不产生数据输出。创建出的资源会被放进 ResourceMgr,并输出一个 ResourceHandle;真正使用 ResourceHandle 来读取数据的代码并不在 ReaderOpKernel 里,这一点从其父类的 ResourceOpKernel::Compute 中就能看出。
// Kernel base class for ops that provide a Reader.
//
// The kernel only creates a ReaderInterface resource through a factory that
// the subclass installs; the surrounding resource handling is delegated to
// ResourceOpKernel<ReaderInterface>::Compute.
class ReaderOpKernel : public ResourceOpKernel<ReaderInterface> {
 public:
  using ResourceOpKernel::ResourceOpKernel;

  // Installs the factory later used by CreateResource(). Subclasses may call
  // this before Compute() runs; in debug builds it checks that no resource
  // has been created yet.
  void SetReaderFactory(std::function<ReaderInterface*()> factory)
      TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    DCHECK(resource_ == nullptr);
    factory_ = factory;
  }

  // Delegates to the parent Compute(). For cancellable kernels a cancellation
  // callback invoking Cancel() is registered first; if registration fails
  // (the step was already cancelled) the context is marked Cancelled and the
  // parent Compute() is not run.
  void Compute(OpKernelContext* context) override {
    if (!IsCancellable()) {
      ResourceOpKernel<ReaderInterface>::Compute(context);
      return;
    }

    // Install cancellation.
    CancellationManager* const manager = context->cancellation_manager();
    const CancellationToken token = manager->get_cancellation_token();
    const bool registered =
        manager->RegisterCallback(token, [this]() { Cancel(); });
    if (!registered) {
      context->SetStatus(errors::Cancelled("read operation was cancelled"));
      return;
    }

    ResourceOpKernel<ReaderInterface>::Compute(context);
  }

 private:
  // Subclasses return true to have Compute() register a cancellation callback.
  virtual bool IsCancellable() const { return false; }

  // Invoked from the cancellation callback; no-op by default.
  virtual void Cancel() {}

  // Creates the Reader by invoking the installed factory. Since the factory
  // may come from a subclass, the concrete Reader type can vary. Returns
  // ResourceExhausted when the factory yields nullptr; on success the factory
  // is cleared.
  Status CreateResource(ReaderInterface** reader)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
    ReaderInterface* const created = factory_();
    *reader = created;
    if (created == nullptr) {
      return errors::ResourceExhausted("Failed to allocate reader");
    }
    // Swap in an empty function so the factory (and anything it captured) is
    // released once the resource exists.
    std::function<ReaderInterface*()> discarded;
    factory_.swap(discarded);
    return Status::OK();
  }

  std::function<ReaderInterface*()> factory_ TF_GUARDED_BY(mu_);
};