CyberRT Base Components


Preface

Notes on the CyberRT base components: the code under the base directory.

The components are called "base", but the techniques are anything but basic: lock-free structures, compile-time detection, std::function, and more.
Much of it is code I'm still not fluent in, the kind interviewers often ask you to write by hand.

Singleton Macro

DEFINE_TYPE_TRAIT

// Generates a trait class `name` that checks whether T has a member function `func`
#define DEFINE_TYPE_TRAIT(name, func)                                                              \
    template <typename T>                                                                          \
    struct name {                                                                                  \
        template <typename Class>                                                                  \
        static constexpr bool Test(decltype(&Class::func) *)                                       \
        {                                                                                          \
            return true;                                                                           \
        }                                                                                          \
        template <typename>                                                                        \
        static constexpr bool Test(...)                                                            \
        {                                                                                          \
            return false;                                                                          \
        }                                                                                          \
                                                                                                   \
        static constexpr bool value = Test<T>(nullptr);                                            \
    };                                                                                             \
                                                                                                   \
    template <typename T>                                                                          \
    constexpr bool name<T>::value;

Generates a class named name that checks at compile time whether T has a member function func.

  • decltype(&Class::func) performs the detection: if the member exists, the first Test overload is viable; otherwise substitution fails and the variadic overload returns false (SFINAE).
  • The result is a constexpr bool name<T>::value, so callers can instantiate name<T> to learn at compile time whether T provides func.
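A minimal self-contained check of the macro (the WithShutdown/WithoutShutdown structs are made up for illustration):

```cpp
#include <string>

// The macro above, reproduced so this snippet compiles on its own.
#define DEFINE_TYPE_TRAIT(name, func)                        \
    template <typename T>                                    \
    struct name {                                            \
        template <typename Class>                            \
        static constexpr bool Test(decltype(&Class::func) *) \
        {                                                    \
            return true;                                     \
        }                                                    \
        template <typename>                                  \
        static constexpr bool Test(...)                      \
        {                                                    \
            return false;                                    \
        }                                                    \
        static constexpr bool value = Test<T>(nullptr);      \
    };                                                       \
    template <typename T>                                    \
    constexpr bool name<T>::value;

struct WithShutdown {
    void Shutdown() {}
};
struct WithoutShutdown {};

DEFINE_TYPE_TRAIT(HasShutdown, Shutdown)

static_assert(HasShutdown<WithShutdown>::value, "WithShutdown has Shutdown()");
static_assert(!HasShutdown<WithoutShutdown>::value, "WithoutShutdown does not");
static_assert(!HasShutdown<std::string>::value, "std::string does not either");
```

The static_asserts already verify the trait at compile time; nothing runs at runtime.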

DECLARE_SINGLETON

// Using the macro above yields a compile-time constant HasShutdown<T>::value
DEFINE_TYPE_TRAIT(HasShutdown, Shutdown)

template <typename T>
typename std::enable_if<HasShutdown<T>::value>::type CallShutdown(T *instance)
{
    instance->Shutdown();
}

template <typename T>
typename std::enable_if<!HasShutdown<T>::value>::type CallShutdown(T *instance)
{
    (void)instance;
}
  • std::enable_if is a template metaprogramming tool that enables or disables a template at compile time based on a condition.
    • typename std::enable_if<HasShutdown<T>::value>::type
      • std::enable_if<HasShutdown<T>::value>: if HasShutdown<T>::value is true, std::enable_if defines a nested member type (void by default); if it is false, there is no such member, so this overload drops out of the candidate set.
      • typename tells the compiler that std::enable_if<HasShutdown<T>::value>::type names a type rather than a member variable or function. So the whole expression is a valid type only when HasShutdown<T>::value is true; otherwise substitution fails and the other overload is chosen.
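To see the two overloads dispatch, here is a self-contained sketch; the Service/Plain types and the hand-rolled HasShutdown are illustrative stand-ins for the macro-generated trait:

```cpp
#include <type_traits>

// Hand-rolled equivalent of HasShutdown<T> so the snippet stands alone.
template <typename T>
struct HasShutdown {
    template <typename C>
    static constexpr bool Test(decltype(&C::Shutdown) *) { return true; }
    template <typename>
    static constexpr bool Test(...) { return false; }
    static constexpr bool value = Test<T>(nullptr);
};

struct Service {
    bool down = false;
    void Shutdown() { down = true; }
};
struct Plain {};

// Selected when T has Shutdown(): actually calls it.
template <typename T>
typename std::enable_if<HasShutdown<T>::value>::type CallShutdown(T *instance)
{
    instance->Shutdown();
}

// Selected otherwise: a no-op, so callers never need to care.
template <typename T>
typename std::enable_if<!HasShutdown<T>::value>::type CallShutdown(T *instance)
{
    (void)instance;
}
```

CallShutdown(&service) actually shuts the service down, while CallShutdown(&plain) compiles to a no-op; the caller writes the same line in both cases.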
#undef UNUSED
#undef DISALLOW_COPY_AND_ASSIGN

#define UNUSED(param) (void)param

/* Disable copy construction and copy assignment */
#define DISALLOW_COPY_AND_ASSIGN(classname)                                                        \
    classname(const classname &) = delete;                                                         \
    classname &operator=(const classname &) = delete;

/* Thread-safe singleton macro built on std::once_flag and std::call_once */
#define DECLARE_SINGLETON(classname)                                                               \
public:                                                                                            \
    static classname *Instance(bool create_if_needed = true)                                       \
    {                                                                                              \
        static classname *instance = nullptr;                                                      \
        if (!instance && create_if_needed) {                                                       \
            static std::once_flag flag;                                                            \
            std::call_once(flag, [&] { instance = new (std::nothrow) classname(); });              \
        }                                                                                          \
        return instance;                                                                           \
    }                                                                                              \
                                                                                                   \
    static void CleanUp()                                                                          \
    {                                                                                              \
        auto instance = Instance(false);                                                           \
        if (instance != nullptr) {                                                                 \
            CallShutdown(instance);                                                                \
        }                                                                                          \
    }                                                                                              \
                                                                                                   \
private:                                                                                           \
    classname();                                                                                   \
    DISALLOW_COPY_AND_ASSIGN(classname)

  • std::nothrow
    • Normally, when new fails to allocate memory it throws std::bad_alloc. Passing std::nothrow makes new return a null pointer on failure instead, so the error can be handled with a pointer check rather than exception handling.
  • std::once_flag && std::call_once guarantee the construction lambda runs exactly once, even when multiple threads call Instance() concurrently.
  • By convention, a class defined with DECLARE_SINGLETON can provide a Shutdown function for cleanup. CleanUp dispatches to it through CallShutdown; if the class has no Shutdown, the no-op overload of CallShutdown is called instead.
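For reference, a hand-written sketch of roughly what DECLARE_SINGLETON(Config) expands to (the Config class and its value member are made up for illustration):

```cpp
#include <mutex>
#include <new>

// Roughly the macro expansion for a class named Config, written out by hand.
class Config {
public:
    static Config *Instance(bool create_if_needed = true)
    {
        static Config *instance = nullptr;
        if (!instance && create_if_needed) {
            static std::once_flag flag;
            // call_once guarantees the lambda runs exactly once, even if
            // several threads race into Instance() simultaneously.
            std::call_once(flag, [&] { instance = new (std::nothrow) Config(); });
        }
        return instance;
    }

    int value = 42;  // illustrative payload

private:
    Config() {}
    Config(const Config &) = delete;
    Config &operator=(const Config &) = delete;
};
```

Every call to Config::Instance() returns the same pointer; Instance(false) only returns an existing instance and never creates one.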

Object Pool

FOR_EACH

    // Yields a compile-time bool HasLess<T>::value: whether T declares operator<
    DEFINE_TYPE_TRAIT(HasLess, operator<) // NOLINT

    template <class Value, class End>
    typename std::enable_if<HasLess<Value>::value && HasLess<End>::value, bool>::type
    LessThan(const Value &val, const End &end)
    {
        return val < end;
    }

    template <class Value, class End>
    typename std::enable_if<!HasLess<Value>::value || !HasLess<End>::value, bool>::type
    LessThan(const Value &val, const End &end)
    {
        return val != end;
    }

#define FOR_EACH(i, begin, end)                                                                    \
    for (auto i = (true ? (begin) : (end)); cyber::base::LessThan(i, (end)); ++i)
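A self-contained sketch of the mechanism (the stand-in HasLess/LessThan mirror the ones above; SumRange/CountEntries are illustrative). Note that HasLess only detects a member operator<, so plain ints fall back to the != comparison, which still terminates a counting loop:

```cpp
#include <map>
#include <type_traits>

// Minimal stand-in for the DEFINE_TYPE_TRAIT-generated HasLess.
template <typename T>
struct HasLess {
    template <typename C>
    static constexpr bool Test(decltype(&C::operator<) *) { return true; }
    template <typename>
    static constexpr bool Test(...) { return false; }
    static constexpr bool value = Test<T>(nullptr);
};

template <class Value, class End>
typename std::enable_if<HasLess<Value>::value && HasLess<End>::value, bool>::type
LessThan(const Value &val, const End &end) { return val < end; }

template <class Value, class End>
typename std::enable_if<!HasLess<Value>::value || !HasLess<End>::value, bool>::type
LessThan(const Value &val, const End &end) { return val != end; }

#define FOR_EACH(i, begin, end) \
    for (auto i = (true ? (begin) : (end)); LessThan(i, (end)); ++i)

int SumRange(int n)  // integer bounds: LessThan falls back to !=
{
    int sum = 0;
    FOR_EACH(i, 0, n) { sum += i; }
    return sum;
}

int CountEntries(const std::map<int, int> &m)  // iterators have no operator<
{
    int count = 0;
    FOR_EACH(it, m.begin(), m.end()) { ++count; }
    return count;
}
```

The `true ? (begin) : (end)` trick forces begin and end to a common type, so the loop variable deduces correctly even when the two arguments differ (e.g. 0 and a uint32_t).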

ObjectPool

    template <typename T>
    class ObjectPool : public std::enable_shared_from_this<ObjectPool<T>>
    {
    public:
        using InitFunc = std::function<void(T *)>;
        using ObjectPoolPtr = std::shared_ptr<ObjectPool<T>>;

        template <typename... Args>
        explicit ObjectPool(uint32_t num_objects, Args &&...args);

        template <typename... Args>
        ObjectPool(uint32_t num_objects, InitFunc f, Args &&...args);

        virtual ~ObjectPool();

        /* Acquire an object from the pool */
        std::shared_ptr<T> GetObject();

    private:
        struct Node {
            T object;
            Node *next;
        };
        ObjectPool(ObjectPool &) = delete;
        ObjectPool &operator=(ObjectPool &) = delete;
        void ReleaseObject(T *object);

        uint32_t num_objects_ = 0;
        char *object_arena_ = nullptr;
        Node *free_head_ = nullptr;
    };
    template <typename T>
    template <typename... Args>
    ObjectPool<T>::ObjectPool(uint32_t num_objects, Args &&...args) : num_objects_(num_objects)
    {
        const size_t size = sizeof(Node);
        object_arena_ = static_cast<char *>(std::calloc(num_objects_, size));

        if (object_arena_ == nullptr) {
            throw std::bad_alloc();
        }

        FOR_EACH(i, 0, num_objects_)
        {
            // placement new constructs a T at this address
            T *obj = new (object_arena_ + i * size) T(std::forward<Args>(args)...);
            reinterpret_cast<Node *>(obj)->next = free_head_;
            free_head_ = reinterpret_cast<Node *>(obj);
        }
    }

    template <typename T>
    template <typename... Args>
    ObjectPool<T>::ObjectPool(uint32_t num_objects, InitFunc f, Args &&...args)
        : num_objects_(num_objects)
    {
        const size_t size = sizeof(Node);
        object_arena_ = static_cast<char *>(std::calloc(num_objects_, size));
        if (object_arena_ == nullptr) {
            throw std::bad_alloc();
        }

        FOR_EACH(i, 0, num_objects_)
        {
            T *obj = new (object_arena_ + i * size) T(std::forward<Args>(args)...);
            f(obj);
            reinterpret_cast<Node *>(obj)->next = free_head_;
            free_head_ = reinterpret_cast<Node *>(obj);
        }
    }

    template <typename T>
    ObjectPool<T>::~ObjectPool()
    {
        if (object_arena_ != nullptr) {
            const size_t size = sizeof(Node);
            FOR_EACH(i, 0, num_objects_)
            {
                reinterpret_cast<Node *>(object_arena_ + i * size)->object.~T();
            }
            std::free(object_arena_);
        }
    }

    /* Acquire an object from the pool */
    template <typename T>
    std::shared_ptr<T> ObjectPool<T>::GetObject()
    {
        if (cyber_unlikely(free_head_ == nullptr)) {
            return nullptr;
        }

        auto self = this->shared_from_this();
        auto obj = std::shared_ptr<T>(reinterpret_cast<T *>(free_head_),
                                      [self](T *object) { self->ReleaseObject(object); });
        free_head_ = free_head_->next;
        return obj;
    }

    template <typename T>
    void ObjectPool<T>::ReleaseObject(T *object)
    {
        if (cyber_unlikely(object == nullptr)) {
            return;
        }
        reinterpret_cast<Node *>(object)->next = free_head_;
        free_head_ = reinterpret_cast<Node *>(object);
    }

Summary (memory is allocated contiguously but managed as a linked list):

  • Each Node holds an object of type T plus a Node *next pointer; together the nodes form a free list.
  • One contiguous block is pre-allocated up front, and placement new constructs the Nodes inside it.
  • When an object is handed out, GetObject returns a shared_ptr whose custom deleter returns the object to the free list instead of freeing memory.
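The allocation trick can be boiled down to a few dozen lines; MiniPool/Point below are illustrative simplifications (no shared_ptr deleter, no variadic constructor forwarding):

```cpp
#include <cstdlib>
#include <new>

struct Point {
    double x = 0;
    double y = 0;
};

class MiniPool {
    struct Node {
        Point object;  // object sits at offset 0, so Point* and Node* alias
        Node *next;
    };

public:
    explicit MiniPool(unsigned n)
    {
        // One zeroed contiguous arena for all nodes.
        arena_ = static_cast<char *>(std::calloc(n, sizeof(Node)));
        for (unsigned i = 0; i < n; ++i) {
            Node *node = reinterpret_cast<Node *>(arena_ + i * sizeof(Node));
            new (&node->object) Point();  // placement new: construct in place
            node->next = free_head_;      // push onto the free list
            free_head_ = node;
        }
    }
    ~MiniPool() { std::free(arena_); }  // Point is trivially destructible

    Point *Get()
    {
        if (free_head_ == nullptr) return nullptr;  // pool exhausted
        Node *node = free_head_;
        free_head_ = node->next;
        return &node->object;
    }

    void Release(Point *p)
    {
        Node *node = reinterpret_cast<Node *>(p);  // valid: object is first member
        node->next = free_head_;
        free_head_ = node;
    }

private:
    char *arena_ = nullptr;
    Node *free_head_ = nullptr;
};
```

Get/Release are O(1) pointer pushes and pops; the real ObjectPool wraps Release into the shared_ptr deleter so callers never return objects manually.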

Lock-Free Read-Write Lock

A lock-free algorithm built on CAS atomic operations.

I have written this many times and forgotten it every time. This time I'm taking proper notes.

ReadLockGuard WriteLockGuard

ReadLockGuard and WriteLockGuard are friend classes of the templated RWLock type; they manage the lock via RAII.

template <typename RWLock>
    class ReadLockGuard
    {
    public:
        explicit ReadLockGuard(RWLock &lock) : rw_lock_(lock) { rw_lock_.ReadLock(); }

        ~ReadLockGuard() { rw_lock_.ReadUnlock(); }

    private:
        ReadLockGuard(const ReadLockGuard &other) = delete;
        ReadLockGuard &operator=(const ReadLockGuard &other) = delete;
        RWLock &rw_lock_;
    };

    template <typename RWLock>
    class WriteLockGuard
    {
    public:
        explicit WriteLockGuard(RWLock &lock) : rw_lock_(lock) { rw_lock_.WriteLock(); }

        ~WriteLockGuard() { rw_lock_.WriteUnlock(); }

    private:
        WriteLockGuard(const WriteLockGuard &other) = delete;
        WriteLockGuard &operator=(const WriteLockGuard &other) = delete;
        RWLock &rw_lock_;
    };

AtomicRWLock

class AtomicRWLock
    {
        friend class ReadLockGuard<AtomicRWLock>;
        friend class WriteLockGuard<AtomicRWLock>;

    public:
        static const int32_t RW_LOCK_FREE = 0; // no thread holds the lock; reads and writes may proceed
        static const int32_t WRITE_EXCLUSIVE = -1; // the lock is held by a single writer
        static const uint32_t MAX_RETRY_TIMES = 5; // consecutive retries before yielding the thread
        AtomicRWLock() {}
        explicit AtomicRWLock(bool write_first) : write_first_(write_first) {}

    private:
        // these functions can only be used through ReadLockGuard/WriteLockGuard
        void ReadLock();
        void WriteLock();

        void ReadUnlock();
        void WriteUnlock();

        AtomicRWLock(const AtomicRWLock &) = delete;
        AtomicRWLock &operator=(const AtomicRWLock &) = delete;
        std::atomic<uint32_t> write_lock_wait_num_ = {0}; // number of threads waiting for the write lock
        std::atomic<int32_t> lock_num_ = {0};             // lock state: -1 writer holds it, 0 free, >0 reader count
        bool write_first_ = true;                         // writer-priority by default
    };
  • RW_LOCK_FREE = 0: no thread holds the lock; both reads and writes may proceed.
  • WRITE_EXCLUSIVE = -1: the lock is held by a single writer.
  • MAX_RETRY_TIMES = 5: number of consecutive failed acquisition attempts before the thread yields.
  • std::atomic<uint32_t> write_lock_wait_num_ = {0}: number of threads waiting for the write lock.
  • std::atomic<int32_t> lock_num_ = {0}: lock state (-1 writer holds it, >=0 number of readers, 0 free).
  • bool write_first_ = true: writer-priority by default.

ReadLock

inline void AtomicRWLock::ReadLock()
    {
        uint32_t retry_times = 0;
        int32_t lock_num = lock_num_.load();
        if (write_first_) {
            do {
                // writer-priority: back off while any writer is waiting
                while (lock_num < RW_LOCK_FREE || write_lock_wait_num_.load() > 0) {
                    if (++retry_times == MAX_RETRY_TIMES) {
                        // saving cpu
                        std::this_thread::yield();
                        retry_times = 0;
                    }
                    lock_num = lock_num_.load();
                }
            } while (!lock_num_.compare_exchange_weak(
                lock_num, lock_num + 1, std::memory_order_acq_rel, std::memory_order_relaxed));
        } else {
            do {
                while (lock_num < RW_LOCK_FREE) {
                    if (++retry_times == MAX_RETRY_TIMES) {
                        // saving cpu
                        std::this_thread::yield();
                        retry_times = 0;
                    }
                    lock_num = lock_num_.load();
                }
            } while (!lock_num_.compare_exchange_weak(
                lock_num, lock_num + 1, std::memory_order_acq_rel, std::memory_order_relaxed));
        }
    }

The read lock is shared: multiple reader threads may hold it simultaneously.

  • Writer-priority mode (write_first_ = true):
    1. Spin until the lock state allows reading (lock_num_ >= 0, i.e. no writer holds it) and no writer is waiting (write_lock_wait_num_ == 0).
    2. After 5 consecutive failed checks, call yield() to give up the CPU instead of busy-waiting.
    3. Atomically increment the lock count with compare_exchange_weak; on success the read lock is held.
  • Non-writer-priority mode (write_first_ = false):
    1. Only the lock state is checked (lock_num_ >= 0); waiting writers are ignored.
    2. The retry and CAS logic is otherwise identical to writer-priority mode.

WriteLock

inline void AtomicRWLock::WriteLock()
    {
        int32_t rw_lock_free = RW_LOCK_FREE;
        uint32_t retry_times = 0;
        write_lock_wait_num_.fetch_add(1);
        while (!lock_num_.compare_exchange_weak(
            rw_lock_free, WRITE_EXCLUSIVE, std::memory_order_acq_rel, std::memory_order_relaxed)) {
            rw_lock_free = RW_LOCK_FREE;
            if (++retry_times == MAX_RETRY_TIMES) {
                // saving cpu
                std::this_thread::yield();
                retry_times = 0;
            }
        }
        write_lock_wait_num_.fetch_sub(1);
    }

The write lock is exclusive: at most one writer may hold it at a time.

  1. fetch_add(1) atomically increments the write-waiter count, announcing that this thread wants the write lock.
  2. Loop on compare_exchange_weak, trying to flip the state from free (0) to writer-held (-1).
  3. After 5 consecutive failures, call yield() to give up the CPU.
  4. Once the write lock is acquired, fetch_sub(1) decrements the write-waiter count.

Unlock

  • ReadUnlock: simply fetch_sub(1) on the reader count; the atomic decrement needs no further checks.
  • WriteUnlock: fetch_add(1) moves the state from -1 back to 0, marking the lock free.
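Because the guards only require the four lock methods, they compose with any lock type. A runnable sketch pairing WriteLockGuard with a tiny CAS-based lock (SpinRWLock and ConcurrentCount are illustrative; the writer-priority policy is omitted):

```cpp
#include <atomic>
#include <cstdint>
#include <thread>
#include <vector>

template <typename RWLock>
class WriteLockGuard {
public:
    explicit WriteLockGuard(RWLock &lock) : rw_lock_(lock) { rw_lock_.WriteLock(); }
    ~WriteLockGuard() { rw_lock_.WriteUnlock(); }

private:
    WriteLockGuard(const WriteLockGuard &) = delete;
    WriteLockGuard &operator=(const WriteLockGuard &) = delete;
    RWLock &rw_lock_;
};

class SpinRWLock {
public:
    void WriteLock()
    {
        int32_t free_val = 0;
        // CAS 0 -> -1; on failure the expected value was overwritten, reset it.
        while (!state_.compare_exchange_weak(free_val, -1)) {
            free_val = 0;
            std::this_thread::yield();
        }
    }
    void WriteUnlock() { state_.fetch_add(1); }  // -1 -> 0

private:
    std::atomic<int32_t> state_{0};  // -1 writer, 0 free, >0 readers
};

int ConcurrentCount(int threads, int increments)
{
    SpinRWLock lock;
    int counter = 0;  // deliberately non-atomic: the write lock protects it
    std::vector<std::thread> pool;
    for (int t = 0; t < threads; ++t) {
        pool.emplace_back([&] {
            for (int i = 0; i < increments; ++i) {
                WriteLockGuard<SpinRWLock> guard(lock);
                ++counter;
            }
        });
    }
    for (auto &th : pool) th.join();
    return counter;
}
```

If the counter ends at exactly threads * increments, the guard-managed lock serialized every increment; without the guard the plain int would lose updates.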

Lock-Free Hash Map

Entry

/**
        1. Create an Entry and store value_ptr.
        2. Access or destroy an Entry and load value_ptr.

        release guarantees all earlier writes complete before the store: the newest data is published.
        acquire guarantees all later reads happen after the load: the reader sees the data published by the matching release.
        */
        struct Entry {
            Entry() {}
            explicit Entry(K key) : key(key)
            {
                value_ptr.store(new V(), std::memory_order_release);
            }
            Entry(K key, const V &value) : key(key)
            {
                value_ptr.store(new V(value), std::memory_order_release);
            }
            Entry(K key, V &&value) : key(key)
            {
                value_ptr.store(new V(std::forward<V>(value)), std::memory_order_release);
            }
            ~Entry() { delete value_ptr.load(std::memory_order_acquire); }

            K key = 0;
            std::atomic<V *> value_ptr = {nullptr};
            std::atomic<Entry *> next = {nullptr};
        };

Bucket

        class Bucket
        {
        public:
            Bucket() : head_(new Entry()) {}
            ~Bucket()
            {
                Entry *ite = head_;
                while (ite) {
                    auto tmp = ite->next.load(std::memory_order_acquire);
                    delete ite;
                    ite = tmp;
                }
            }

            bool Has(K key)
            {
                Entry *m_target = head_->next.load(std::memory_order_acquire);
                while (Entry *target = m_target) {
                    if (target->key < key) {
                        m_target = target->next.load(std::memory_order_acquire);
                        continue;
                    } else {
                        return target->key == key;
                    }
                }
                return false;
            }

            // Find the entry for key and the entry preceding it
            bool Find(K key, Entry **prev_ptr, Entry **target_ptr)
            {
                Entry *prev = head_;
                Entry *m_target = head_->next.load(std::memory_order_acquire);
                while (Entry *target = m_target) {
                    if (target->key == key) {
                        *prev_ptr = prev;
                        *target_ptr = target;
                        return true;
                    } else if (target->key > key) {
                        *prev_ptr = prev;
                        *target_ptr = target;
                        return false;
                    } else {
                        prev = target;
                        m_target = target->next.load(std::memory_order_acquire);
                    }
                }
                *prev_ptr = prev;
                *target_ptr = nullptr;
                return false;
            }

            void Insert(K key, const V &value)
            {
                Entry *prev = nullptr;
                Entry *target = nullptr;
                Entry *new_entry = nullptr;
                V *new_value = nullptr;
                while (true) {
                    if (Find(key, &prev, &target)) {
                        // key exists, update value
                        if (!new_value) {
                            new_value = new V(value);
                        }
                        auto old_val_ptr = target->value_ptr.load(std::memory_order_acquire);
                        if (target->value_ptr.compare_exchange_strong(old_val_ptr, new_value,
                                                                      std::memory_order_acq_rel,
                                                                      std::memory_order_relaxed)) {
                            delete old_val_ptr;
                            if (new_entry) {
                                delete new_entry;
                                new_entry = nullptr;
                            }
                            return;
                        }
                        continue;
                    } else {
                        if (!new_entry) {
                            new_entry = new Entry(key, value);
                        }
                        new_entry->next.store(target, std::memory_order_release);
                        if (prev->next.compare_exchange_strong(target, new_entry,
                                                               std::memory_order_acq_rel,
                                                               std::memory_order_relaxed)) {
                            // Insert success
                            if (new_value) {
                                delete new_value;
                                new_value = nullptr;
                            }
                            return;
                        }
                        // another entry has been inserted, retry
                    }
                }
            }

           ......

            bool Get(K key, V **value)
            {
                Entry *prev = nullptr;
                Entry *target = nullptr;
                if (Find(key, &prev, &target)) {
                    *value = target->value_ptr.load(std::memory_order_acquire);
                    return true;
                }
                return false;
            }

            Entry *head_; // dummy head of this bucket's list
        };

AtomicHashMap

template <typename K, typename V, std::size_t TableSize = 128,
              typename std::enable_if<
                  std::is_integral<K>::value && (TableSize & (TableSize - 1)) == 0, int>::type = 0>
    class AtomicHashMap
    {
    public:
        AtomicHashMap() : capacity_(TableSize), mode_num_(capacity_ - 1) {}
        AtomicHashMap(const AtomicHashMap &other) = delete;
        AtomicHashMap &operator=(const AtomicHashMap &other) = delete;
        bool Has(K key)
        {
            uint64_t index = key & mode_num_;
            return table_[index].Has(key);
        }

        bool Get(K key, V **value)
        {
            uint64_t index = key & mode_num_;
            return table_[index].Get(key, value);
        }

        bool Get(K key, V *value)
        {
            uint64_t index = key & mode_num_;
            V *val = nullptr;
            bool res = table_[index].Get(key, &val);
            if (res) {
                *value = *val;
            }
            return res;
        }

        void Set(K key)
        {
            uint64_t index = key & mode_num_;
            table_[index].Insert(key);
        }

        void Set(K key, const V &value)
        {
            uint64_t index = key & mode_num_;
            table_[index].Insert(key, value);
        }

        void Set(K key, V &&value)
        {
            uint64_t index = key & mode_num_;
            table_[index].Insert(key, std::forward<V>(value));
        }

private:
        Bucket table_[TableSize]; // bucket array
        uint64_t capacity_;
        uint64_t mode_num_;
    };
  • (TableSize & (TableSize - 1)) == 0 requires TableSize to be a power of two.
  • typename std::enable_if< std::is_integral<K>::value && (TableSize & (TableSize - 1)) == 0, int>::type = 0>
    • The template is valid only when K is an integral type and TableSize is a power of two; when the condition holds, this extra template parameter has type int with default value 0.
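The power-of-two requirement exists because key & (TableSize - 1) then equals key % TableSize, costing one AND instead of a division; BucketIndex below is an illustrative helper:

```cpp
#include <cstdint>

// Valid only when table_size is a power of two: the mask keeps exactly the
// low log2(table_size) bits, which is the remainder modulo table_size.
constexpr uint64_t BucketIndex(uint64_t key, uint64_t table_size)
{
    return key & (table_size - 1);
}

static_assert(BucketIndex(130, 128) == 130 % 128, "mask equals modulo");
static_assert(BucketIndex(127, 128) == 127, "last bucket");
static_assert((128 & (128 - 1)) == 0, "128 passes the power-of-two check");
static_assert((100 & (100 - 1)) != 0, "100 fails it");
```

This is exactly why AtomicHashMap precomputes mode_num_ = capacity_ - 1 and indexes with key & mode_num_.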

Unbounded Lock-Free Queue

template <typename T>
    class UnboundedQueue
    {
    public:
        UnboundedQueue() { Reset(); }
        UnboundedQueue &operator=(const UnboundedQueue &other) = delete;
        UnboundedQueue(const UnboundedQueue &other) = delete;

        ~UnboundedQueue() { Destroy(); }

        void Clear()
        {
            Destroy();
            Reset();
        }

        void Enqueue(const T &element)
        {
            auto node = new Node();
            node->data = element;
            Node *old_tail = tail_.load();

            while (true) {
                if (tail_.compare_exchange_strong(old_tail, node)) {
                    old_tail->next = node;
                    old_tail->release();
                    size_.fetch_add(1);
                    break;
                }
            }
        }

        bool Dequeue(T *element)
        {
            Node *old_head = head_.load();
            Node *head_next = nullptr;
            do {
                head_next = old_head->next;

                if (head_next == nullptr) {
                    return false;
                }
            } while (!head_.compare_exchange_strong(old_head, head_next));
            *element = head_next->data;
            size_.fetch_sub(1);
            old_head->release();
            return true;
        }

        size_t Size() { return size_.load(); }

        bool Empty() { return size_.load() == 0; }

    private:
        struct Node {
            T data;
            std::atomic<uint32_t> ref_count;
            Node *next = nullptr;
            Node() { ref_count.store(2); }
            void release()
            {
                ref_count.fetch_sub(1);
                if (ref_count.load() == 0) {
                    delete this;
                }
            }
        };

        void Reset()
        {
            auto node = new Node();
            head_.store(node);
            tail_.store(node);
            size_.store(0);
        }

        void Destroy()
        {
            auto ite = head_.load();
            Node *tmp = nullptr;
            while (ite != nullptr) {
                tmp = ite->next;
                delete ite;
                ite = tmp;
            }
        }

        std::atomic<Node *> head_;
        std::atomic<Node *> tail_;
        std::atomic<size_t> size_;
    };
  • An unbounded lock-free queue organized as a linked list of Node. Reset() installs a dummy node so head_ and tail_ are never null.
  • Each Node starts with ref_count = 2 because it has two logical owners: the enqueuer releases it when tail_ moves past it, and the dequeuer releases it when head_ moves past it; the node may be deleted only after both releases.
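One subtlety worth noting: release() above does fetch_sub(1) and then a separate load(); if both owners decrement before either loads, both can observe 0 and delete twice. The conventional safe pattern keys off the value fetch_sub itself returns; a sketch (RefCounted is an illustrative stand-in for Node's counting logic):

```cpp
#include <atomic>
#include <cstdint>

struct RefCounted {
    std::atomic<uint32_t> ref_count{2};  // one share per logical owner

    // Returns true exactly once, for the last owner: fetch_sub returns the
    // value *before* the decrement, so only one caller can observe 1.
    bool Release() { return ref_count.fetch_sub(1) == 1; }
};
```

The caller that gets true is the unique last owner and may safely delete the node; no second load is needed.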

Bounded Lock-Free Queue

template <typename T>
    class BoundedQueue
    {
    public:
        using value_type = T;
        using size_type = uint64_t;

    public:
        BoundedQueue() {}
        BoundedQueue &operator=(const BoundedQueue &other) = delete;
        BoundedQueue(const BoundedQueue &other) = delete;
        ~BoundedQueue();
        bool Init(uint64_t size);
        bool Init(uint64_t size, WaitStrategy *strategy);
        bool Enqueue(const T &element);
        bool Enqueue(T &&element);
        bool WaitEnqueue(const T &element);
        bool WaitEnqueue(T &&element);
        bool Dequeue(T *element);
        bool WaitDequeue(T *element);
        uint64_t Size();
        bool Empty();
        void SetWaitStrategy(WaitStrategy *WaitStrategy);
        void BreakAllWait();
        uint64_t Head() { return head_.load(); }
        uint64_t Tail() { return tail_.load(); }
        uint64_t Commit() { return commit_.load(); }

    private:
        uint64_t GetIndex(uint64_t num);

        alignas(CACHELINE_SIZE) std::atomic<uint64_t> head_ = {0};
        alignas(CACHELINE_SIZE) std::atomic<uint64_t> tail_ = {1};
        alignas(CACHELINE_SIZE) std::atomic<uint64_t> commit_ = {1};
        // alignas(CACHELINE_SIZE) std::atomic<uint64_t> size_ = {0};
        uint64_t pool_size_ = 0;
        T *pool_ = nullptr;
        std::unique_ptr<WaitStrategy> wait_strategy_ = nullptr;
        volatile bool break_all_wait_ = false;
    };
  • A bounded lock-free queue built on contiguous storage: a single array accessed through a T *, with head_ and tail_ holding the head and tail indices.
  • commit_ coordinates producers and consumers: it marks the last position that has been fully written and is safe for a consumer to read. Only positions up to commit_ are guaranteed to contain completed writes. ⭐
template <typename T>
    BoundedQueue<T>::~BoundedQueue()
    {
        if (wait_strategy_) {
            BreakAllWait();
        }
        if (pool_) {
            for (uint64_t i = 0; i < pool_size_; ++i) {
                pool_[i].~T();
            }
            std::free(pool_);
        }
    }

    /* The default wait strategy is the sleep strategy */
    template <typename T>
    inline bool BoundedQueue<T>::Init(uint64_t size)
    {
        return Init(size, new SleepWaitStrategy());
    }

    /* Specify the queue size and the wait strategy */
    template <typename T>
    bool BoundedQueue<T>::Init(uint64_t size, WaitStrategy *strategy)
    {
        // Head and tail each occupy a space
        pool_size_ = size + 2;
        pool_ = reinterpret_cast<T *>(std::calloc(pool_size_, sizeof(T)));
        if (pool_ == nullptr) {
            return false;
        }
        for (uint64_t i = 0; i < pool_size_; ++i) {
            new (&(pool_[i])) T();
        }
        wait_strategy_.reset(strategy);
        return true;
    }

    template <typename T>
    bool BoundedQueue<T>::Enqueue(const T &element)
    {
        uint64_t new_tail = 0;
        uint64_t old_commit = 0;
        uint64_t old_tail = tail_.load(std::memory_order_acquire);
        do {
            new_tail = old_tail + 1;
            if (GetIndex(new_tail) == GetIndex(head_.load(std::memory_order_acquire))) {
                return false;
            }
        } while (!tail_.compare_exchange_weak(old_tail, new_tail, std::memory_order_acq_rel,
                                              std::memory_order_relaxed));
        pool_[GetIndex(old_tail)] = element;
        do {
            old_commit = old_tail;
        } while (cyber_unlikely(!commit_.compare_exchange_weak(
            old_commit, new_tail, std::memory_order_acq_rel, std::memory_order_relaxed)));
        wait_strategy_->NotifyOne();
        return true;
    }

    template <typename T>
    bool BoundedQueue<T>::Enqueue(T &&element)
    {
        uint64_t new_tail = 0;
        uint64_t old_commit = 0;
        uint64_t old_tail = tail_.load(std::memory_order_acquire);
        do {
            new_tail = old_tail + 1;
            if (GetIndex(new_tail) == GetIndex(head_.load(std::memory_order_acquire))) {
                return false;
            }
        } while (!tail_.compare_exchange_weak(old_tail, new_tail, std::memory_order_acq_rel,
                                              std::memory_order_relaxed));
        pool_[GetIndex(old_tail)] = std::move(element);
        do {
            old_commit = old_tail;
        } while (cyber_unlikely(!commit_.compare_exchange_weak(
            old_commit, new_tail, std::memory_order_acq_rel, std::memory_order_relaxed)));
        wait_strategy_->NotifyOne();
        return true;
    }

    template <typename T>
    bool BoundedQueue<T>::Dequeue(T *element)
    {
        uint64_t new_head = 0;
        uint64_t old_head = head_.load(std::memory_order_acquire);
        do {
            new_head = old_head + 1;
            if (new_head == commit_.load(std::memory_order_acquire)) {
                return false;
            }
            *element = pool_[GetIndex(new_head)];
        } while (!head_.compare_exchange_weak(old_head, new_head, std::memory_order_acq_rel,
                                              std::memory_order_relaxed));
        return true;
    }

    /* Enqueue that waits according to the wait strategy */
    template <typename T>
    bool BoundedQueue<T>::WaitEnqueue(const T &element)
    {
        while (!break_all_wait_) {
            if (Enqueue(element)) {
                return true;
            }
            if (wait_strategy_->EmptyWait()) {
                continue;
            }
            // wait timeout
            break;
        }

        return false;
    }

    /* Enqueue (move overload) that waits according to the wait strategy */
    template <typename T>
    bool BoundedQueue<T>::WaitEnqueue(T &&element)
    {
        while (!break_all_wait_) {
            if (Enqueue(std::move(element))) {
                return true;
            }
            if (wait_strategy_->EmptyWait()) {
                continue;
            }
            // wait timeout
            break;
        }

        return false;
    }

    template <typename T>
    bool BoundedQueue<T>::WaitDequeue(T *element)
    {
        while (!break_all_wait_) {
            /* If the queue has data, dequeue and return true immediately */
            if (Dequeue(element)) {
                return true;
            }
            /* Otherwise run the wait strategy */
            if (wait_strategy_->EmptyWait()) {
                continue;
            }
            // wait timeout
            break;
        }

        return false;
    }

    template <typename T>
    inline uint64_t BoundedQueue<T>::Size()
    {
        return tail_ - head_ - 1;
    }

    template <typename T>
    inline bool BoundedQueue<T>::Empty()
    {
        return Size() == 0;
    }

    /* Maps the monotonically growing counter to an array index; equivalent to num % pool_size_ */
    template <typename T>
    inline uint64_t BoundedQueue<T>::GetIndex(uint64_t num)
    {
        return num - (num / pool_size_) * pool_size_; // faster than %
    }

    template <typename T>
    inline void BoundedQueue<T>::SetWaitStrategy(WaitStrategy *strategy)
    {
        wait_strategy_.reset(strategy);
    }

    template <typename T>
    inline void BoundedQueue<T>::BreakAllWait()
    {
        break_all_wait_ = true;
        wait_strategy_->BreakAllWait();
    }
  • The GetIndex trick: an integer divide, multiply, and subtract replace the modulo operator.
  • Init: calloc allocates size + 2 zeroed slots of sizeof(T) (head and tail each occupy one slot); reinterpret_cast<T *> converts the void * to T *; new (&(pool_[i])) T() uses placement new to construct each element in place. Initial state: head_ = 0, tail_ = 1, commit_ = 1.
  • A producer enqueues in three steps:
    • Claim the next write slot via tail_ (new_tail = old_tail + 1), so concurrent producers never contend for the same slot;
    • Write the element into the claimed slot (the index of old_tail);
    • Atomically advance commit_ to new_tail, announcing that the slot at old_tail has been fully written.
      Here commit_ tells consumers: everything up to commit_ is complete and safe to read.
  • A consumer dequeues as follows:
    • Compute the next read position (new_head = old_head + 1);
    • If new_head == commit_, no fully written data is available (claimed slots may still be mid-write), so the dequeue fails;
    • If new_head < commit_, the element at new_head has been fully written (commit_ confirms it) and can be read safely.
      Here commit_ is the consumer's safety boundary: it prevents reading a slot a producer has claimed but not finished writing.
  • WaitEnqueue, the strategy-based enqueue:
    • If the queue is not full, insert and return immediately; otherwise wait according to the strategy.
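The ring arithmetic in isolation (GetIndex and Size reproduced as free functions for illustration):

```cpp
#include <cstdint>

// GetIndex wraps a monotonically growing counter with divide/multiply/subtract
// instead of the % operator; the result is identical to num % pool_size.
constexpr uint64_t GetIndex(uint64_t num, uint64_t pool_size)
{
    return num - (num / pool_size) * pool_size;
}

// Size() == tail - head - 1, so the initial state head=0, tail=1 is empty,
// and the queue is treated as full when GetIndex(tail + 1) == GetIndex(head).
constexpr uint64_t Size(uint64_t head, uint64_t tail) { return tail - head - 1; }

static_assert(GetIndex(7, 5) == 7 % 5, "same result as modulo");
static_assert(Size(0, 1) == 0, "freshly initialized queue is empty");
```

Because head_/tail_/commit_ grow without bound and are only reduced modulo pool_size_ at access time, index comparisons stay simple and never need wrap-around special cases.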


Wait Strategies

  • WaitStrategy: the base class.
  • BlockWaitStrategy: blocking wait using a mutex plus condition variable (mutex + condition_variable).
  • SleepWaitStrategy: sleeps for a fixed interval when the queue is empty.
  • YieldWaitStrategy: std::this_thread::yield(); the current thread gives up its time slice.
  • BusySpinWaitStrategy: busy-waits, always returning true.
  • TimeoutBlockWaitStrategy: cv_.wait_for(lock, time_out_) waits up to a timeout; returns true if the thread is notified within the timeout, false otherwise.
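A sketch of what the timeout-block strategy can look like, using the method names the queue calls (EmptyWait / NotifyOne / BreakAllWait); the actual CyberRT class may differ in detail:

```cpp
#include <chrono>
#include <condition_variable>
#include <mutex>

class TimeoutBlockWaitStrategy {
public:
    explicit TimeoutBlockWaitStrategy(std::chrono::milliseconds timeout)
        : time_out_(timeout) {}

    // true: woken by a notify, the caller should retry; false: timed out.
    bool EmptyWait()
    {
        std::unique_lock<std::mutex> lock(mutex_);
        return cv_.wait_for(lock, time_out_) == std::cv_status::no_timeout;
    }

    void NotifyOne() { cv_.notify_one(); }
    void BreakAllWait() { cv_.notify_all(); }

private:
    std::mutex mutex_;
    std::condition_variable cv_;
    std::chrono::milliseconds time_out_;
};
```

This maps directly onto WaitDequeue's loop: a true from EmptyWait means "retry the dequeue", while false falls through to the wait-timeout break.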

Thread Pool

    class ThreadPool
    {
    public:
        explicit ThreadPool(std::size_t thread_num, std::size_t max_task_num = 1000);

        template <typename F, typename... Args>
        auto Enqueue(F &&f, Args &&...args)
            -> std::future<typename std::result_of<F(Args...)>::type>;

        ~ThreadPool();

    private:
        std::vector<std::thread> workers_;
        BoundedQueue<std::function<void()>> task_queue_;
        std::atomic_bool stop_;
    };
  • workers_: an array of std::thread.
  • task_queue_: the bounded lock-free queue serves as the task queue.
  • stop_: whether the pool has been shut down.
    /* The constructor takes the thread count and the maximum number of queued tasks */
    inline ThreadPool::ThreadPool(std::size_t threads, std::size_t max_task_num) : stop_(false)
    {
        /* Initialize the BoundedQueue with the blocking wait strategy */
        if (!task_queue_.Init(max_task_num, new BlockWaitStrategy())) {
            throw std::runtime_error("Task queue init failed.");
        }

        /* Start the worker threads; each runs a while loop that pulls tasks */
        workers_.reserve(threads);
        for (size_t i = 0; i < threads; ++i) {
            workers_.emplace_back([this] {
                while (!stop_) {
                    /* a callable object returning void */
                    std::function<void()> task;
                    if (task_queue_.WaitDequeue(&task)) {
                        /* A successful dequeue means a task was claimed: run it */
                        task();
                    }
                }
            });
        }
    }

    // before using the return value, you should check value.valid()
    template <typename F, typename... Args>
    auto ThreadPool::Enqueue(F &&f, Args &&...args)
        -> std::future<typename std::result_of<F(Args...)>::type>
    {
        using return_type = typename std::result_of<F(Args...)>::type;

        auto task = std::make_shared<std::packaged_task<return_type()>>(
            std::bind(std::forward<F>(f), std::forward<Args>(args)...));

        std::future<return_type> res = task->get_future();

        // don't allow enqueueing after stopping the pool
        if (stop_) {
            return std::future<return_type>();
        }
        task_queue_.Enqueue([task]() { (*task)(); });
        return res;
    }

    // the destructor joins all threads
    /* Wake every worker thread, then join them all and release resources */
    inline ThreadPool::~ThreadPool()
    {
        if (stop_.exchange(true)) {
            return;
        }
        task_queue_.BreakAllWait();
        for (std::thread &worker : workers_) {
            worker.join();
        }
    }
  • ThreadPool::Enqueue takes two template parameters:
    • F: a callable object.
    • Args: the arguments to pass to the callable.
    • It returns std::future<typename std::result_of<F(Args...)>::type>.
      • std::result_of<F(Args...)>::type deduces the type returned by calling F with Args, i.e. return_type.
  • std::packaged_task<return_type()>(std::bind(std::forward<F>(f), std::forward<Args>(args)...)) packages the callable:
    • std::packaged_task can wrap any callable target, which is what makes the asynchronous call possible.
      • return_type() is the template argument of std::packaged_task: the wrapped callable takes no arguments and returns return_type.
      • get_future() retrieves the future object associated with the task.
        • Calling the task's operator() executes it;
        • the result is stored in the future;
        • reset() lets the packaged_task be reused, producing a fresh future object.
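The packaged_task / future plumbing from Enqueue, extracted into a standalone function (RunViaPackagedTask is an illustrative stand-in for the pool):

```cpp
#include <functional>
#include <future>
#include <memory>
#include <thread>

int RunViaPackagedTask(int a, int b)
{
    using return_type = int;
    // Bind the callable and its arguments into a zero-argument task.
    auto task = std::make_shared<std::packaged_task<return_type()>>(
        std::bind([](int x, int y) { return x + y; }, a, b));

    std::future<return_type> res = task->get_future();

    // Type-erase to void(), exactly what the pool pushes onto its queue.
    // shared_ptr is needed because packaged_task is move-only but
    // std::function requires a copyable target.
    std::function<void()> erased = [task]() { (*task)(); };

    std::thread worker(erased);  // stand-in for a pool worker picking it up
    worker.join();
    return res.get();  // the result was stored by (*task)()
}
```

The same three-step shape appears in Enqueue: package the call, keep the future for the caller, and hand the type-erased void() closure to whichever worker dequeues it.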