Go-线程安全map学习笔记

212 阅读4分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第21天,点击查看活动详情

线程不安全的map

总所周知,Go中内置容器Map是线程不安全的,当有多个协程同时进行操作,会抛出fatal error: concurrent map writes的错误,例如下面的例子

func testMap() {
    m := make(map[string]int)
    for i := 0; i < 100; i++ {
        go func(num int) {
            m[fmt.Sprintf("%d", num)] = num
        }(i)
    }
}

因此如果我们有多个协程同时操作map的需求,就不能直接使用map啦。通常我们可以通过加读写锁(sync.RWMutex)或者使用sync.Map来保证线程安全。

使用读写锁的map

type RWSafeMap struct {
    sync.RWMutex // 使用读写锁
    m map[string]int
}
​
func NewRWSafeMap(len int) *RWSafeMap {
    return &RWSafeMap{
        m: make(map[string]int, len),
    }
}
​
func (m *RWSafeMap) Store(key string, value int) {
    m.Lock()
    defer m.Unlock()
    m.m[key] = value
}
​
func (m *RWSafeMap) Load(key string) (int, bool) {
    m.RLock()
    defer m.RUnlock()
    value, ok := m.m[key]
    return value, ok
}
​
func (m *RWSafeMap) Delete(key string) {
    m.Lock()
    defer m.Unlock()
    delete(m.m, key)
}
​
func (m *RWSafeMap) Len() int {
    m.RLock()
    defer m.RUnlock()
    return len(m.m)
}
​
func testRWSafeMap() {
    num := 100
    safeMap := NewRWSafeMap(num)
    waitGroup := sync.WaitGroup{}
    waitGroup.Add(num)
    for i := 0; i < num; i++ {
        go func(num int) {
            defer waitGroup.Done()
            safeMap.Store(fmt.Sprintf("%d", num), num)
        }(i)
    }
    waitGroup.Wait()
    fmt.Println(safeMap.Len()) // 100
}
​

可以看到通过添加读写锁,RWSafeMap已经能够满足线程安全的的需求了,但同样因为加锁而导致性能下降。Go还提供了Sync.Map这个线程安全的map,它适合读多写少的场景,在大量的读的场景下,速度可以比通过添加读写锁快一倍,下面就来看看它的相关API吧

sync.Map

func testSyncMap() {
    var syncMap sync.Map
    waitGroup := sync.WaitGroup{}
    waitGroup.Add(100)
    for i := 0; i < 100; i++ {
        go func(num int) {
            defer waitGroup.Done()
            syncMap.Store(fmt.Sprintf("%d", num), num) // 添加键值对
        }(i)
    }
    waitGroup.Wait()
    var sum = 0
    // 遍历map
    syncMap.Range(func(key, value any) bool {
        str := key.(string)
        num := value.(int)
        fmt.Printf("key=%s\tvalue=%d\n", str, num)
        sum++
        return true
    })
    fmt.Println(sum)
​
    // 通过key获得value
    value, ok := syncMap.Load("50")
    fmt.Println(value, ok)
    // 删除键值对
    syncMap.Delete("1")
}
​

那么sync.Map是通过什么来保证线程安全的呢?

它的结构体如下

type Map struct {
    mu Mutex // 互斥锁、用来保护 read和dirty字段
    
    read atomic.Value // 本质上也是个原生的map,支持并发的读取,其他操作则通过加锁来保证线程安全
​
    dirty map[any]*entry // 包含需要加锁才能访问的元素,包括在read字段中不存在的键值对
​
    misses int // 记录从read中miss的次数,当miss数和dirty长度一样时,则会把dirty提升为read
}
​
type readOnly struct {
    m       map[any]*entry
    amended bool // dirty是否包含read不存在的数据
}
​
// entry代表一个值
type entry struct {
    p unsafe.Pointer // *interface{}
}

store方法

func (m *Map) Store(key, value any) {
    read, _ := m.read.Load().(readOnly) // 获得只读map
    if e, ok := read.m[key]; ok && e.tryStore(&value) { // 如果对应key存在、则尝试通过CAS更新值
        return
    }
    // 否则加锁,进行操作
    m.mu.Lock()
    read, _ = m.read.Load().(readOnly)
    if e, ok := read.m[key]; ok {
        if e.unexpungeLocked() {
            // 如果已经该键值对已被删除则添加到dirty中
            m.dirty[key] = e
        }
        e.storeLocked(&value) // 更新值
    } else if e, ok := m.dirty[key]; ok { // 如果dirty存在该键值对则直接更新
        e.storeLocked(&value) 
    } else {
        // 如果都不存在则是新key
        if !read.amended {
            // 创建一个dirty
            m.dirtyLocked()
            m.read.Store(readOnly{m: read.m, amended: true})
        }
        m.dirty[key] = newEntry(value) // 将值存放进dirty中
    }
    m.mu.Unlock()
}
​
  1. 先读取read判断键是否存在,如果存在则尝试使用CAS更新值。
  2. 如果键不存在或者CAS更新失败则直接加锁并继续往下走。
  3. 如果该键值对之前被标记删除,那么先将这个键值对写到 dirty 中,同时更新 read,或者如果 dirty 中已经有这一项了,直接更新 read。
  4. 如果是一个新键值对,dirty为空则通过read复制创建一个新的dirty,随后将新值注入

Load方法

func (m *Map) Load(key any) (value any, ok bool) {
    read, _ := m.read.Load().(readOnly)
    e, ok := read.m[key]
    if !ok && read.amended { // 如果read中不存在该键值对,同时dirty有额外的值
        m.mu.Lock()
        read, _ = m.read.Load().(readOnly) // 双重判断一下,提高性能
        e, ok = read.m[key]
        if !ok && read.amended { 
            e, ok = m.dirty[key]
            m.missLocked() // 核心方法
        }
        m.mu.Unlock()
    }
    if !ok {
        return nil, false
    }
    return e.load()
}
​
func (m *Map) missLocked() {
    m.misses++
    if m.misses < len(m.dirty) { 
        return
    }
    // 如果misses等于dirty的长度,则将dirty提升为read,并将原来的dirty设为nil
    m.read.Store(readOnly{m: m.dirty})
    m.dirty = nil
    m.misses = 0
}
​
​
  1. 先在read中尝试寻找键值对,如果找到则直接返回
  2. 如果不存在且dirty有额外的值,则从dirty尝试寻找
  3. 如果misses值已经大于等于dirty的长度,则将dirty变为read,以提升查找效率

小小总结

sync.Map就是通过读写分离的方式来实现,对于读/更新等操作都尽量在不加锁的read中进行操作,对于写则会使用加锁的操作dirty来实现。

同时如果read中未命中次数太多,则会动态调整,将dirty提升为read,以提升查找效率。因此官方只推荐在读多写少的情况下使用sync.Map,如果增删较多还是使用读写锁来保证map的线程把~

参考资料