Go 如何实现一个线程安全的Map?

462 阅读8分钟

引言

Go内置的Map是非线程安全的,在并发场景下使用会导致数据异常,甚至引发painc,为了解决这个问题,很多项目中都需要自己实现线程安全的Map。如何实现?其实不论什么Java、Go还是其它什么语言,在实际应用中通常有三种实现思路:

  1. 在使用Map时加一把锁,以此保证临界资源的安全
  2. 对Map加读写锁,根据场景将读写数据分开,减少使用锁的时间
  3. Map加锁时并不锁住所有的临界资源,将资源分段精细化管理,每次只锁住访问资源所在的分段,减少锁住的数据范围

事实上,Go在1.9之后新增的sync.Map就是使用的第二种思想实现的线程安全的Map,但是sync.Map更适合在读多写少的场景,在写多的场景性能很差,感兴趣的可以查询相关资料看看。本文精力放在使用分段思想实现一个线程安全的锁。

本文实现思路是参照concurrent-Map,可以看做是它的源码学习。

分段锁

分段锁思想是一种并发控制技术,它将共享资源划分成多个不重叠的片段,并对每个片段进行独立的加锁,通过这种方式可以减小锁粒度,提高系统的并发性能。举个容易理解的例子:以前大家都在游乐场的大门口排队,排到你了你进去玩你想玩的项目,现在把大门拆了,所有人在各自想玩的项目门前排队,这样是不是效率高很多?

实现

有了上边的指导思想,我们接下来就开始实现一个简单的现场安全的锁。

Segment

我们需要一个数组来表示多个分段,每个分段对应一把锁,分段内部使用go内置的map和读写锁sync.RWMutex,每个分段自持的锁来控制分段内部访问达到整个锁的线程安全的目的。先来定一个分段Segment

type Segment struct {
   sync.RWMutex                        //继承读写锁,保证访问内部map的安全
   kvs          map[string]interface{} //分段内部使用map
}

那我们实现的线程安全锁自然就是 type ConcurrentMap []*Segment

接下来是初始化ConcurrentMap的方法

func NewConcurrentMap() ConcurrentMap {
   //根据默认值创建多个分段
   segments := make([]*Segment, DEFAULT_SEGMENT_COUNT)

   for i := range segments {
      segments[i] = &Segment{
         kvs: make(map[string]interface{}),
      }
   }
   return segments

}

根据配置的分段数量,为map分配分段所需空间,注意不要忘记为每个分段的map分配空间

getSegment

接下来要实现Map中几个常用的方法比如:Put、Get、Size、Remove、Pop,但是在这之前,需要先实现一个根据key定位segment的方法,由于我们把key人为的分到多个分段中,哪个key应该存放在哪个segment中,需要制定一个映射寻址规则,这是存放、查询数据的基础。

func (c ConcurrentMap) getSegment(key string) *Segment {
   s := c[getHash(key)%DEFAULT_SEGMENT_COUNT]
   return s
}

func getHash(key string) int {
	//这里用key的长度代替key对应的hash,仅做说明问题
	return len(key)
}

这里使用key的长度对Segment数量取模来确定这个key应当存放在哪个segment。

注意,这里仅做了一个非常简单的实现用以说明问题,实际应用不能这么干,这个映射算法一定要尽可能的把key均匀散落在各个segment中,感兴趣的去看看相关实现,这里不是本文重点不再赘述。

Put和Get

//Put
func (c ConcurrentMap) Put(key, value string) {
   seg := c.getSegment(key)
   //锁住这个分段,防止被并发修改
   seg.Lock()
   seg.kvs[key] = value
   seg.Unlock()
}

//Get
func (c ConcurrentMap) Get(key string) (bool, interface{}) {
   seg := c.getSegment(key)
   //读锁锁住这个分段,防止读过程中被修改
   seg.RLock()
   v, ok := seg.kvs[key]
   //释放锁
   seg.RUnlock()
   return ok, v
}

Put()时,先通过key值寻址定位到对应的segment,让后将key和value写入到segment内部维护的map中,注意这里可能存在并发写入,需要加一个写锁。Get()的实现类似,区别并不需要写锁,读锁就可以了(读写锁:读写互斥、写写互斥、读读不互斥)。

Size

func (c ConcurrentMap) Size() int {

   //遍历每个分段,累加每个分段的元素个数
   var count = 0
   for i := 0; i < DEFAULT_SEGMENT_COUNT; i++ {
      seg := c[i]
      //计算分段中元素个数时,加个读锁防止被修改
      seg.RLock()
      count += len(seg.kvs)
      seg.RUnlock()
   }

   return count
}

ConcurrentMap并没有维护一个属性来记录map中元素个数,原因可能是维护一个属性,获取总个数倒很方便,但是设计到元素增减,都要加锁进行修改这个值,反而会影响性能。这里获取ConcurrentMap元素个数是通过累加每个megment中存放的元素个数。

Remove和Pop

func (c ConcurrentMap) Remove(key string) {

   seg := c.getSegment(key)
   //移除seg中的元素,加锁
   seg.Lock()
   //使用内置delete函数删除map中的元素
   delete(seg.kvs, key)
   seg.Unlock()

}

// Pop 获取并移除key
func (c ConcurrentMap) Pop(key string) interface{} {
   seg := c.getSegment(key)
   seg.Lock()
   v, ok := seg.kvs[key]
   if ok {
      //存在key,则删除
      delete(seg.kvs, key)
   }
   seg.Unlock()
   return v
}

Remove()移除元素,实现也很简单,先定位segment,然后从内部map中将key及其元素移除,注意这里同样需要加写锁。Pop() 获取并移除key,可以理解为Get()Remove()组合,实现直接看代码吧。

keys

func (c ConcurrentMap) Keys() []string {

   //需要知道有多少个key,用来确定chan大小
   size := c.Size()
   keysChan := make(chan string, size)

   wg := sync.WaitGroup{}
   wg.Add(len(c))

   go func() {
      //每个segment一个协程遍历获取key,并发提高性能
      for i, seg := range c {
         go func(i int, seg *Segment) {
            seg.RLock()
            for k := range seg.kvs {
               keysChan <- k
            }
            wg.Done()
            seg.RUnlock()
         }(i, seg)
      }
      //等待所有的segmeng遍历完再关闭keysChan
      wg.Wait()
      //使用完keysChan关闭,以防死锁
      close(keysChan)
   }()

   re := make([]string, 0, size)

   //阻塞监听,等待消费
   for k := range keysChan {
      re = append(re, k)
   }
   return re
}

Keys()获取ConcurrentMap中所有key,这个实现相对复杂一点,复杂来自于我们想提高性能。由于segment相对独立,我们可以同时遍历所有的segment中key,最后在汇总在一起,这样要比挨个遍历segment获取所有key效率大幅提高,是不是?

同时遍历segment这个好解决,开协程就完事了,但是最后还要将所有segment得到的key汇总,这个怎么实现效率好一些?比较直接容易想到的方式是,将每个segment遍历的结果放到一个临时数组,最终将临时结果再遍历汇总,但是这要等待所有segment遍历完成,然后再用同步的方式汇总,性能差点意思!那肯定有人说,就使用一个数组,所有segment的遍历到的key都写到一个数组中,可以避免最后的汇总过程,但是往一个数组中写势必要加锁,也会带来开销!还有没有更好的办法呢?

说到这里,使用chan来解决这个问题呼之欲出,将所有并发遍历segment的得到的key写到chan中,可以避免直接加锁,并且进入chan的key可以马上被写入到汇总的数组中,无需等待所有的key被遍历,这相当于汇总的过程和遍历segment的过程是同时进行的,这样的效率肯定高啊!下面看一下具体实现的要点。

  1. 要开一个带有缓冲的chan用于存放key,缓冲的大小使用ConcurrentMap总元素个数(也未必需要这么大,因为chan中元素边进边出)
  2. 为每个segment开一个协程遍历key,将key放入chan,遍历key时要加读锁
  3. 消费chan中数据,将key汇总到一个数组,注意,这个过程和步骤2一定要在不同的协程中,这样才能达到汇总的过程和遍历segment的过程是同时进行的目的
  4. 当所有的segment都遍历完毕,要关闭chan,为了实现这一点引入sync.WaitGroup{},保证segment都遍历完再关闭chan。

Iter

func (c ConcurrentMap) Iter() chan KvPair {
   //需要知道有多少个key,用来确定chan大小
   size := c.Size()
   keysChan := make(chan KvPair, size)

   wg := sync.WaitGroup{}
   wg.Add(len(c))

   go func() {
      //每个segment一个协程遍历获取key,并发提高性能
      for i, seg := range c {
         go func(i int, seg *Segment) {
            seg.RLock()
            for k, v := range seg.kvs {
               keysChan <- KvPair{k, v}
            }
            wg.Done()
            seg.RUnlock()
         }(i, seg)
      }
      //等待所有的segmeng遍历完再关闭keysChan
      wg.Wait()
      //使用完keysChan关闭,以防死锁
      close(keysChan)
   }()

   return keysChan
}

// KvPair 封装键值对
type KvPair struct {
	Key string
	Val interface{}
}

Iter()遍历map中的kv是一个很常用的操作,实现方式和Keys()很像,返回结果是一个KvPair类型的chan,KvPair是封装的键值对。为什么不返回一个KvPair数组?原因和Keys()中提到的汇总过程一样,为了调用者更高效的遍历kv,而不用等到所有kv对都取出来后,调用者再进行遍历使用。

总结

线程安全的Map使用场景非常多,不同的语言会结合自己的语言特性来实现,但常见的实现思路就这几种,通过本文我们可以学会分段锁的思想。另外Keys()Iter()的实现源码和设计很值得学习、借鉴!