SnowFlake 雪花算法详解与C++实现

449 阅读4分钟

背景

现在是大数据时代,互联网每天都在产生海量的数据。大数据量会导致我们在数据库存储数据的时候,需要分库分表,从而保证数据库查询的性能。对于水平分表就需要保证表中 id 的全局唯一性。

在使用Redis实现分布式锁的时候,也需要指定用户的唯一ID。

对于 MySQL 而言,一个表中的主键 id 一般使用自增的方式,但是如果进行水平分表之后,多个表中会生成重复的 id 值。那么如何保证水平分表后的多张表中的 id 是全局唯一性的呢?

如果还是借助数据库主键自增的形式,那么可以让不同表初始化一个不同的初始值,然后按指定的步长进行自增。例如有3张拆分表,初始主键值为1,2,3,自增步长为3。

当然也有人使用 UUID 来作为主键,但是 UUID 生成的是一个无序的字符串,对于 MySQL 推荐使用增长的数值类型值作为主键来说不适合。因为MySQL的默认存储引擎是InnoDB,底层采用的是B+树作为索引的数据结构。对于聚簇索引(主键索引)来说,B+树的底层叶子节点存储的是索引号和具体的数据,这些索引号和数据是按照主键的大小按照顺序排列的。如果主键ID不是递增的话,那么在插入新数据的时候基会导致叶子节点的分裂,会影响性能。如果主键是递增的,那么新增的数据只需要追加到原有的数据后面,叶子节点不需要分裂。所以UUID不太合适,会影响性能。

当然还有其他解决方案,但是雪花算法是最广泛被使用的一个用于解决分布式 id 的高效方案,也是许多互联网公司在推荐使用的。

snowflake算法来源于Twitter,使用scala语言实现,在GitHub上就能找到开源代码,感兴趣可以去看看。

SnowFlake 雪花算法

image.png

算法的实现思想也不难:

雪花算法原理就是生成一个的64位比特位的 long 类型的唯一 id。

  • 最高1位固定值0,因为生成的 id 是正整数,如果是1就是负数了。
  • 接下来41位存储毫秒级时间戳,2^41/(1000606024365)=69,大概可以使用69年。
  • 再接下10位存储机器码,包括5位 datacenterId 和5位 workerId。最多可以部署2^10=1024台机器。
  • 最后12位存储序列号。同一毫秒时间戳时,通过这个递增的序列号来区分。即对于同一台机器而言,同一毫秒时间戳下,可以生成2^12=4096个不重复 id。

可以将雪花算法作为一个单独的服务进行部署,然后需要全局唯一 id 的系统,请求雪花算法服务获取 id 即可。

对于每一个雪花算法服务,需要先指定10位的机器码,这个根据自身业务进行设定即可。例如机房号+机器号,机器号+服务号,或者是其他可区别标识的10位比特位的整数值都行。

雪花算法优点

雪花算法有以下几个优点:

  • 高并发分布式环境下生成不重复 id,每秒可生成百万个不重复 id。
  • 基于时间戳,以及同一时间戳下序列号自增,基本保证 id 有序递增。
  • 算法简单,在内存中进行,效率高。

C++代码实现

Github上也有人用C++做了snowflake最基本的实现,这里直接查看源码:

#include <cstdint>
#include <chrono>
#include <stdexcept>
#include <mutex>

class snowflake_nonlock
{
public:
    void lock()
    {
    }
    void unlock()
    {
    }
};

template<int64_t Twepoch, typename Lock = snowflake_nonlock>
class snowflake
{
    using lock_type = Lock;
    static constexpr int64_t TWEPOCH = Twepoch;
    static constexpr int64_t WORKER_ID_BITS = 5L;
    static constexpr int64_t DATACENTER_ID_BITS = 5L;
    static constexpr int64_t MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1;
    static constexpr int64_t MAX_DATACENTER_ID = (1 << DATACENTER_ID_BITS) - 1;
    static constexpr int64_t SEQUENCE_BITS = 12L;
    static constexpr int64_t WORKER_ID_SHIFT = SEQUENCE_BITS;
    static constexpr int64_t DATACENTER_ID_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS;
    static constexpr int64_t TIMESTAMP_LEFT_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS + DATACENTER_ID_BITS;
    static constexpr int64_t SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1;

    using time_point = std::chrono::time_point<std::chrono::steady_clock>;

    time_point start_time_point_ = std::chrono::steady_clock::now();
    int64_t start_millsecond_ = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();

    int64_t last_timestamp_ = -1;
    int64_t workerid_ = 0;
    int64_t datacenterid_ = 0;
    int64_t sequence_ = 0;
    lock_type lock_;
public:
    snowflake() = default;

    snowflake(const snowflake&) = delete;

    snowflake& operator=(const snowflake&) = delete;

    void init(int64_t workerid, int64_t datacenterid)
    {
        if (workerid > MAX_WORKER_ID || workerid < 0) {
            throw std::runtime_error("worker Id can't be greater than 31 or less than 0");
        }

        if (datacenterid > MAX_DATACENTER_ID || datacenterid < 0) {
            throw std::runtime_error("datacenter Id can't be greater than 31 or less than 0");
        }

        workerid_ = workerid;
        datacenterid_ = datacenterid;
    }

    int64_t nextid()
    {
        std::lock_guard<lock_type> lock(lock_);
        //std::chrono::steady_clock  cannot decrease as physical time moves forward
        auto timestamp = millsecond();
        if (last_timestamp_ == timestamp)
        {
            sequence_ = (sequence_ + 1)&SEQUENCE_MASK;
            if (sequence_ == 0)
            {
                timestamp = wait_next_millis(last_timestamp_);
            }
        }
        else
        {
            sequence_ = 0;
        }

        last_timestamp_ = timestamp;

        return ((timestamp - TWEPOCH) << TIMESTAMP_LEFT_SHIFT)
            | (datacenterid_ << DATACENTER_ID_SHIFT)
            | (workerid_ << WORKER_ID_SHIFT)
            | sequence_;
    }

private:
    int64_t millsecond() const noexcept
    {
        auto diff = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start_time_point_);
        return start_millsecond_ + diff.count();
    }

    int64_t wait_next_millis(int64_t last) const noexcept
    {
        auto timestamp = millsecond();
        while (timestamp <= last)
        {
            timestamp = millsecond();
        }
        return timestamp;
    }
};