TensorFlow source code analysis: cancellation

This article is part of Juejin's "New Creator Gift" (新人创作礼) event.

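Cancellation means stopping work that is already under way. A thread that is running a long operation, or a loop that never ends on its own, cannot stop itself, so another thread has to request the cancellation, typically by setting a flag that the running thread checks: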

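#include <atomic>

std::atomic<bool> stop{false};  // shared flag; atomic so both threads can access it safely
// Run by thread 1
while (!stop) {  // thread 1 keeps looping until the flag is set
    // do something
}
// Thread 2 cancels thread 1
stop = true;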

This pattern is common for cancelling I/O operations and otherwise infinite loops.
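A runnable version of the flag pattern, as a minimal sketch. The function name `RunAndCancel` and the "work" inside the loop are hypothetical placeholders. Using `std::atomic<bool>` rather than a plain `int` or `bool` avoids a data race between the two threads.

```cpp
#include <atomic>
#include <chrono>
#include <thread>

// Hypothetical helper: runs a worker that loops until another thread sets
// `stop`, then returns how many iterations the worker completed.
long RunAndCancel(std::chrono::milliseconds delay) {
  std::atomic<bool> stop{false};
  long iterations = 0;  // only touched by the worker until join()

  std::thread worker([&] {
    while (!stop.load(std::memory_order_acquire)) {
      ++iterations;  // stand-in for one unit of real work
      std::this_thread::yield();
    }
  });

  std::this_thread::sleep_for(delay);
  stop.store(true, std::memory_order_release);  // "thread 2" requests cancellation
  worker.join();  // returns only because the worker observed stop == true
  return iterations;
}
```

Reading `iterations` after `join()` is safe because `join()` synchronizes with the end of the worker thread.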

Using TensorFlow cancellation

tensorflow/core/framework/cancellation.h

Let's look at a few test cases to see the basic usage:

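  • Create a CancellationManager with new
  • Get a token
  • Register a callback with the token
  • Optionally, use the token to deregister the callback
  • Call the CancellationManager's StartCancel method; the registered callbacks are invoked at this point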

TEST(Cancellation, SimpleCancel) {
  bool is_cancelled = false;
  CancellationManager* manager = new CancellationManager();
  auto token = manager->get_cancellation_token();   // get a unique token before registering
  bool registered = manager->RegisterCallback(
      token, [&is_cancelled]() { is_cancelled = true; }); // this callback runs when StartCancel is called
  EXPECT_TRUE(registered);
  manager->StartCancel();  // trigger cancellation
  EXPECT_TRUE(is_cancelled); // the flag is now true, so the callback ran
  delete manager;
}

TEST(Cancellation, StartCancelTriggersAllCallbacks) {
  bool is_cancelled_1 = false;
  bool is_cancelled_2 = false;
  auto manager = std::make_unique<CancellationManager>();
  auto token_1 = manager->get_cancellation_token();
  EXPECT_TRUE(manager->RegisterCallbackWithErrorLogging(  // register cancellation callback 1
      token_1, [&is_cancelled_1]() { is_cancelled_1 = true; }, "TestCallback"));
  auto token_2 = manager->get_cancellation_token();
  EXPECT_TRUE(manager->RegisterCallback(  // register cancellation callback 2
      token_2, [&is_cancelled_2]() { is_cancelled_2 = true; }));
  manager->StartCancel(); // runs both callbacks

  EXPECT_TRUE(is_cancelled_1);
  EXPECT_TRUE(is_cancelled_2);
}

TEST(Cancellation, CancelMultiple) {
  bool is_cancelled_1 = false, is_cancelled_2 = false, is_cancelled_3 = false;
  auto manager = std::make_unique<CancellationManager>();
  auto token_1 = manager->get_cancellation_token();
  bool registered_1 = manager->RegisterCallback(
      token_1, [&is_cancelled_1]() { is_cancelled_1 = true; });
  EXPECT_TRUE(registered_1);
  auto token_2 = manager->get_cancellation_token();
  bool registered_2 = manager->RegisterCallback(
      token_2, [&is_cancelled_2]() { is_cancelled_2 = true; });
  EXPECT_TRUE(registered_2);
  EXPECT_FALSE(is_cancelled_1);
  EXPECT_FALSE(is_cancelled_2);
  manager->StartCancel();
  EXPECT_TRUE(is_cancelled_1);
  EXPECT_TRUE(is_cancelled_2);
  EXPECT_FALSE(is_cancelled_3);
  auto token_3 = manager->get_cancellation_token();
  bool registered_3 = manager->RegisterCallback(
      token_3, [&is_cancelled_3]() { is_cancelled_3 = true; });
  EXPECT_FALSE(registered_3);
  EXPECT_FALSE(is_cancelled_3);
}
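The tests above never exercise deregistering a callback with its token. The following `SimpleCancellationManager` is a hypothetical, simplified sketch, not TensorFlow's implementation. It mirrors the behavior the tests check plus deregistration: callbacks run once on `StartCancel`, deregistered callbacks do not run, and registering after cancellation returns false. Unlike the real class, its `DeregisterCallback` does not wait for callbacks that are already running.

```cpp
#include <atomic>
#include <cstdint>
#include <functional>
#include <mutex>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for tensorflow::CancellationManager (sketch only).
class SimpleCancellationManager {
 public:
  using Token = int64_t;
  using Callback = std::function<void()>;

  // fetch_add returns the old value, so every caller gets a distinct token.
  Token get_cancellation_token() { return next_token_.fetch_add(1); }

  // Returns false, and stores nothing, if cancellation already started.
  bool RegisterCallback(Token token, Callback cb) {
    std::lock_guard<std::mutex> lock(mu_);
    if (cancelled_) return false;
    callbacks_[token] = std::move(cb);
    return true;
  }

  // Removes the callback; returns false if cancellation already started.
  bool DeregisterCallback(Token token) {
    std::lock_guard<std::mutex> lock(mu_);
    if (cancelled_) return false;
    callbacks_.erase(token);
    return true;
  }

  void StartCancel() {
    std::unordered_map<Token, Callback> to_run;
    {
      std::lock_guard<std::mutex> lock(mu_);
      if (cancelled_) return;  // cancel at most once
      cancelled_ = true;
      to_run.swap(callbacks_);
    }
    // Run callbacks outside the lock so they may take other locks.
    for (auto& entry : to_run) entry.second();
  }

  bool IsCancelled() {
    std::lock_guard<std::mutex> lock(mu_);
    return cancelled_;
  }

 private:
  std::atomic<Token> next_token_{0};
  std::mutex mu_;
  bool cancelled_ = false;
  std::unordered_map<Token, Callback> callbacks_;
};
```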

A look at the source code

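The source shows that CancellationManagers can form a tree in which each node may have any number of children. Calling StartCancel on a parent also cancels all of its children (and, recursively, their children).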

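  1. Pass the parent directly to the CancellationManager constructor.
  2. The parent records each child with RegisterChild. It is private in the header, so it is invoked from the child's constructor rather than by user code; the children are kept in a doubly linked list.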
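A minimal sketch of how cancellation can propagate down such a tree. `TreeCancellationManager` is a hypothetical name. It stores children in a `std::vector` instead of TensorFlow's intrusive doubly linked list, keeps no callbacks, and does not handle a child being destroyed before its parent (the header's `DeregisterChild` suggests the real class unlinks children).

```cpp
#include <mutex>
#include <vector>

// Hypothetical sketch: cancellation flowing from a parent to its descendants.
class TreeCancellationManager {
 public:
  TreeCancellationManager() = default;

  // The child registers itself with its parent. A child created under an
  // already-cancelled parent starts out cancelled.
  explicit TreeCancellationManager(TreeCancellationManager* parent) {
    std::lock_guard<std::mutex> lock(parent->mu_);
    if (parent->cancelled_) {
      cancelled_ = true;
    } else {
      parent->children_.push_back(this);
    }
  }

  // Marks this node cancelled, then cancels every registered child.
  void StartCancel() {
    std::vector<TreeCancellationManager*> children;
    {
      std::lock_guard<std::mutex> lock(mu_);
      if (cancelled_) return;  // already cancelled: nothing to do
      cancelled_ = true;
      children.swap(children_);
    }
    for (TreeCancellationManager* child : children) child->StartCancel();
  }

  bool IsCancelled() {
    std::lock_guard<std::mutex> lock(mu_);
    return cancelled_;
  }

 private:
  std::mutex mu_;
  bool cancelled_ = false;
  std::vector<TreeCancellationManager*> children_;  // not owned
};
```

Cancelling a subtree leaves the rest of the tree untouched, while cancelling the root reaches every node that is still live.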


typedef int64_t CancellationToken; // a token is just an integer

typedef std::function<void()> CancelCallback; // type of the callbacks run by StartCancel

// This class should never simultaneously be used as the cancellation manager
// for two separate sets of executions (i.e two separate steps, or two separate
// function executions).
class CancellationManager {
 public:
  static const CancellationToken kInvalidToken;
  CancellationManager();

  // Constructs a new CancellationManager that is a "child" of `*parent`.
  //
  // If `*parent` is cancelled, `*this` will be cancelled. `*parent` must
  // outlive the created CancellationManager.
  // Creates a CancellationManager whose parent is `parent`.
  // If the parent calls StartCancel, this child is cancelled as well.
  explicit CancellationManager(CancellationManager* parent);

  ~CancellationManager();

  // Runs all registered callbacks.
  void StartCancel();

  // Run all callbacks associated with this manager with a status.
  // Currently the status is for logging purpose only. See also
  // CancellationManager::RegisterCallbackWithErrorLogging.
  void StartCancelWithStatus(const Status& status);

  // Returns true iff StartCancel() has been called.
  bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }

  // Returns the current counter value as a new token and atomically increments the counter.
  CancellationToken get_cancellation_token() {
    return next_cancellation_token_.fetch_add(1);
  }

  // Attempts to register the given callback to be invoked when this
  // manager is cancelled. Returns true if the callback was
  // registered; returns false if this manager was already cancelled,
  // and the callback was not registered.
  //
  // If this method returns false, it is the caller's responsibility
  // to perform any cancellation cleanup.
  //
  // This method is tricky to use correctly. The following usage pattern
  // is recommended:
  //
  // class ObjectWithCancellableOperation {
  //   mutex mu_;
  //   void CancellableOperation(CancellationManager* cm,
  //                             std::function<void(Status)> callback) {
  //     bool already_cancelled;
  //     CancellationToken token = cm->get_cancellation_token();
  //     {
  //       mutex_lock(mu_);
  //       already_cancelled = !cm->RegisterCallback(
  //           [this, token]() { Cancel(token); });
  //       if (!already_cancelled) {
  //         // Issue asynchronous operation. Associate the pending operation
  //         // with `token` in some object state, or provide another way for
  //         // the Cancel method to look up the operation for cancellation.
  //         // Ensure that `cm->DeregisterCallback(token)` is called without
  //         // holding `mu_`, before `callback` is invoked.
  //         // ...
  //       }
  //     }
  //     if (already_cancelled) {
  //       callback(errors::Cancelled("Operation was cancelled"));
  //     }
  //   }
  //
  //   void Cancel(CancellationToken token) {
  //     mutex_lock(mu_);
  //     // Take action to cancel the operation with the given cancellation
  //     // token.
  //   }
  //
  // NOTE(mrry): The caller should take care that (i) the calling code
  // is robust to `callback` being invoked asynchronously (e.g. from
  // another thread), (ii) `callback` is deregistered by a call to
  // this->DeregisterCallback(token) when the operation completes
  // successfully, and (iii) `callback` does not invoke any method
  // on this cancellation manager. Furthermore, it is important that
  // the eventual caller of the complementary DeregisterCallback does not
  // hold any mutexes that are required by `callback`.
  bool RegisterCallback(CancellationToken token, CancelCallback callback);

  // Similar to RegisterCallback, but if the cancellation manager starts a
  // cancellation with an error status, it will log the error status before
  // invoking the callback. `callback_name` is a human-readable name of the
  // callback, which will be displayed on the log.
  bool RegisterCallbackWithErrorLogging(CancellationToken token,
                                        CancelCallback callback,
                                        tensorflow::StringPiece callback_name);

 
  bool DeregisterCallback(CancellationToken token);

  bool TryDeregisterCallback(CancellationToken token);

  bool IsCancelling();

 private:
  struct CallbackConfiguration {
    CancelCallback callback;
    std::string name;
    bool log_error = false;
  };

  struct State {
    Notification cancelled_notification;
    gtl::FlatMap<CancellationToken, CallbackConfiguration> callbacks;

    // If this CancellationManager has any children, this member points to the
    // head of a doubly-linked list of its children.
    CancellationManager* first_child = nullptr;  // Not owned.
  };

  bool RegisterCallbackConfig(CancellationToken token,
                              CallbackConfiguration config);

  bool RegisterChild(CancellationManager* child);
  void DeregisterChild(CancellationManager* child);

  bool is_cancelling_;
  std::atomic_bool is_cancelled_;
  std::atomic<CancellationToken> next_cancellation_token_;

  CancellationManager* const parent_ = nullptr;  // Not owned.

  bool is_removed_from_parent_ TF_GUARDED_BY(parent_->mu_) = false;
  // Siblings are linked into the parent's doubly linked list of children.
  CancellationManager* prev_sibling_ TF_GUARDED_BY(parent_->mu_) =
      nullptr;  // Not owned.
  CancellationManager* next_sibling_ TF_GUARDED_BY(parent_->mu_) =
      nullptr;  // Not owned.

  mutex mu_;
  std::unique_ptr<State> state_ TF_GUARDED_BY(mu_);
};
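`get_cancellation_token()` relies on `fetch_add`, which atomically returns the old value and increments the counter, so concurrent callers never receive the same token. A small standalone check of that property; `GetToken` and `CountDistinctTokens` are hypothetical helpers, not TensorFlow APIs.

```cpp
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <set>
#include <thread>
#include <vector>

// Same idea as get_cancellation_token(): one atomic counter shared by all callers.
std::atomic<int64_t> next_token{0};

int64_t GetToken() { return next_token.fetch_add(1); }

// Hands out tokens from several threads and returns how many distinct
// tokens were observed.
std::size_t CountDistinctTokens(int num_threads, int per_thread) {
  std::mutex mu;
  std::set<int64_t> seen;
  std::vector<std::thread> threads;
  for (int t = 0; t < num_threads; ++t) {
    threads.emplace_back([&] {
      for (int i = 0; i < per_thread; ++i) {
        int64_t token = GetToken();
        std::lock_guard<std::mutex> lock(mu);  // protects `seen`, not the counter
        seen.insert(token);
      }
    });
  }
  for (auto& th : threads) th.join();
  return seen.size();
}
```

If the counter were a plain `int64_t` incremented with `++`, two threads could read the same value and the count could fall short.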

// Registers the given cancellation callback, returning a function that can be
// used to deregister the callback. If `cancellation_manager` is NULL, no
// registration occurs and `deregister_fn` will be a no-op.
Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
                                    std::function<void()> callback,
                                    std::function<void()>* deregister_fn);
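The header comment only says that a null `cancellation_manager` makes `deregister_fn` a no-op. The sketch below fills in the rest under stated assumptions: it registers the callback with a fresh token, reports failure when the manager is already cancelled (a `bool` stands in for TensorFlow's `Status`), and sets `deregister_fn` to a closure that calls `DeregisterCallback`. `MiniManager` and `RegisterCancellationCallbackSketch` are hypothetical names.

```cpp
#include <cstdint>
#include <functional>
#include <map>
#include <mutex>
#include <utility>

// Hypothetical minimal manager, just enough to exercise the helper below.
class MiniManager {
 public:
  int64_t get_cancellation_token() {
    std::lock_guard<std::mutex> lock(mu_);
    return next_token_++;
  }
  bool RegisterCallback(int64_t token, std::function<void()> cb) {
    std::lock_guard<std::mutex> lock(mu_);
    if (cancelled_) return false;
    callbacks_[token] = std::move(cb);
    return true;
  }
  bool DeregisterCallback(int64_t token) {
    std::lock_guard<std::mutex> lock(mu_);
    if (cancelled_) return false;
    callbacks_.erase(token);
    return true;
  }
  void StartCancel() {
    std::map<int64_t, std::function<void()>> to_run;
    {
      std::lock_guard<std::mutex> lock(mu_);
      if (cancelled_) return;
      cancelled_ = true;
      to_run.swap(callbacks_);
    }
    for (auto& entry : to_run) entry.second();  // run outside the lock
  }

 private:
  std::mutex mu_;
  int64_t next_token_ = 0;
  bool cancelled_ = false;
  std::map<int64_t, std::function<void()>> callbacks_;
};

// Assumed contract: false means "already cancelled"; with a null manager
// nothing is registered and *deregister_fn becomes a no-op.
bool RegisterCancellationCallbackSketch(MiniManager* manager,
                                        std::function<void()> callback,
                                        std::function<void()>* deregister_fn) {
  if (manager == nullptr) {
    *deregister_fn = [] {};
    return true;
  }
  int64_t token = manager->get_cancellation_token();
  if (!manager->RegisterCallback(token, std::move(callback))) return false;
  *deregister_fn = [manager, token] { manager->DeregisterCallback(token); };
  return true;
}
```

Returning the deregistration step as a closure lets the caller undo the registration without keeping track of the token itself.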
