[Golang早读] sync.WaitGroup是如何实现的

252 阅读4分钟

什么是sync.WaitGroup

官方文档对其的描述是:WaitGroup等待一组goroutine的任务完成。主goroutine调用添加以设置要等待的goroutine的数量。然后,每个goroutine都会运行并在完成后调用Done。同时,可以使用Wait来阻塞,直到所有goroutine完成。下面是一个简单的使用案例:

func Main() {
    wg := sync.WaitGroup{}
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            time.Sleep(10 * time.Second)
        }()

    }
    wg.Wait() // 等待在此,等所有go func里都执行了Done()才会退出
}

源码分析

type WaitGroup struct {
   noCopy noCopy   //检测使用后不许被copy,只有在使用go vet检查时才能显示错误
   
   state1 uint64
   state2 uint32
}
  • noCapy。其中,noCopy是检测禁止拷贝的技术,WaitGroup在第一次使用后不能再被复制
  • state1字段
    • 高32位为请求计数器counter,代表目前尚未完成的协程个数
    • 低32位为等待计数器waiter,代表目前已调用 Wait 的 goroutine 的个数,因为Wait()方法支持并发,每一次Wait()方法执行,等待计数器就加1
  • state2为信号量

WaitGroup 的整个调用过程可以简单地描述成下面这样:

  1. 当调用WaitGroup.Add(n)时,counter 将会自增: counter + n
  2. 当调用WaitGroup.Wait()时,会将 waiter++。同时调用 runtime_Semacquire(semap), 增加信号量,并挂起当前 goroutine
  3. 当调用 WaitGroup.Done() 时,将会 counter--。如果自减后的 counter 等于 0,说明 WaitGroup 的等待过程已经结束,则需要调用 runtime_Semrelease 释放信号量,唤醒正在 WaitGroup.Wait() 的 goroutine
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
   if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
      // 如果是64位对齐的,什么也不用做
      return &wg.state1, &wg.state2
   } else {
      //如果是32位对齐,数组数组的后两个元素用来做state,可以用来做64位的原子操作,第一个元素32bit用来做信号量
      state := (*[3]uint32)(unsafe.Pointer(&wg.state1))
      return (*uint64)(unsafe.Pointer(&state[1])), &state[0]
   }
}

如果变量是 64 位对齐 (8 byte), 则该变量的起始地址是 8 的倍数。如果变量是 32 位对齐 (4 byte),则该变量的起始地址是 4 的倍数。

当 state1 是 32 位的时候,那么state1被当成是一个数组[3]uint32,数组的第一位是semap信号量,第二三位存储着counter, waiter正好是64位。

image.png

为什么会有这种奇怪的设定呢?这里涉及两个前提:

前提 1:在 WaitGroup 的真实逻辑中, counter 和 waiter 被合在了一起,当成一个 64 位的整数对外使用。当需要变化 counter 和 waiter 的值的时候,也是通过 atomic 来原子操作这个 64 位整数。

前提 2:在 32 位系统下,如果使用 atomic 对 64 位变量进行原子操作,调用者需要自行保证变量的 64 位对齐,否则将会出现异常。golang 的官方文档 sync/atomic/#pkg-note-BUG 原文是这么说的:

在ARM、x86-32和32位MIPS上,调用方负责安排以原子方式访问的64位字的64位对齐。变量或已分配的结构、数组或片中的第一个单词可以依赖于64位对齐

因此,在前提 1 的情况下,WaitGroup 需要对 64 位进行原子操作。根据前提 2,WaitGroup 需要自行保证 count+waiter 的 64 位对齐。

源码部分

Add方法实现

主要操作的state1字段中计数值部分,计数器部分的逻辑主要是通过state(),在上面有提及。每次调用Add方法就会增加相应数量的计数器。如果计数器为零,则释放等待时阻塞的所有goroutine。如果计数器变为负数,请添加恐慌。如果计数器值大于0,说明此时还有任务没有完成,那么调用者就变成等待者,需要加入wait队列,并且阻塞自己。参数可正可负数。如果一个WaitGroup被重用来等待几个独立的事件集,那么新的Add调用必须在所有先前的wait调用返回之后发生。

func (wg *WaitGroup) Add(delta int) {
   //wg.state()返回的是state1中的状态位和信号量
   statep, semap := wg.state()
   if race.Enabled {
      _ = *statep // trigger nil deref early
      if delta < 0 {
         // Synchronize decrements with Wait.
         race.ReleaseMerge(unsafe.Pointer(wg))
      }
      race.Disable()
      defer race.Enable()
   }
   // uint64(delta)<<32 将delta左移32
   // 原子操作修改statep高32位的值,即counter的值
   state := atomic.AddUint64(statep, uint64(delta)<<32)
   
   //右移32位,使高32位变成低32位,得到当前counter计数器的值
   v := int32(state >> 32)
   
   //直接取低32位的值,就是阻塞的waiter数量
   w := uint32(state)
   if race.Enabled && delta > 0 && v == int32(delta) {
      // The first increment must be synchronized with Wait.
      // Need to model this as a read, because there can be
      // several concurrent wg.counter transitions from 0.
      race.Read(unsafe.Pointer(semap))
   }
   //当counter<0,panic
   if v < 0 {
      panic("sync: negative WaitGroup counter")
   }
   //当waiter的数量不为0时,累加后的counter等和delta相等
   //说明Add()和Wait()同时调用了,发生panic
   //正确的方法是先Add()后Wait(),Wait()之后就不允许再Add任务了
   if w != 0 && delta > 0 && v == int32(delta) {
      panic("sync: WaitGroup misuse: Add called concurrently with Wait")
   }
   //Add()调用结束
   if v > 0 || w == 0 {
      return
   }
   
   // 剩下的就是 counter == 0 且 waiter != 0 的情况
   // 在这个情况下,*statep 的值就是 waiter 的值,否则就有问题
   // 在这个情况下,所有的任务都已经完成,可以将 *statep 整个置0
   // 同时向所有的Waiter释放信号量
   if *statep != state {
      panic("sync: WaitGroup misuse: Add called concurrently with Wait")
   }
   // 将状态位重置清0
   *statep = 0
   for ; w != 0; w-- {
      // 首先让信号量加一,然后检查是否有正在等待的Goroutine,如果没有,直接返回;
      // 如果有,调用goready函数唤醒一个Goroutine。
      runtime_Semrelease(semap, false, 0)
   }
}

下面的流程图模拟了Add()步骤(引用自:zhuanlan.zhihu.com/p/365288361…

image.png

Done()方法实现

内部调用Add(-1)的操作

func (wg *WaitGroup) Done() {
   wg.Add(-1)
}
Wait()方法实现

阻塞主goroutine直到WaitGroup计数器变为0

func (wg *WaitGroup) Wait() {
   // 获取waitgroup状态位和信号量
   statep, semap := wg.state()
   if race.Enabled {
      _ = *statep // trigger nil deref early
      race.Disable()
   }
   for {
      //使用原子操作读取statep,是为了保证Add的写入操作已经完成
      state := atomic.LoadUint64(statep)
      //右移高32位,获得counter计数器的值
      v := int32(state >> 32)  
      //直接取值低32位,获得waiter计数器的值
      w := uint32(state)
      //计数器为0,跳出死循环,不用阻塞
      if v == 0 {
         if race.Enabled {
            race.Enable()
            race.Acquire(unsafe.Pointer(wg))
         }
         return
      }
      // 使用CAS操作对waiter、counter计数器进行加1操作
      // 外面有for循环保证可以进行重试
      if atomic.CompareAndSwapUint64(statep, state, state+1) {
         if race.Enabled && w == 0 {
            race.Write(unsafe.Pointer(semap))
         }
         
         // 在这里获取信号量,使线程进入睡眠状态,
         // 与Add方法中runtime_Semrelease增加信号量相对应,
         // 也就是当最后一个任务调用Done方法
         // 后会调用Add方法对counter的值减到0,
         // 就会走到最后的增加信号量
         runtime_Semacquire(semap)
         
         // 在Add方法中增加信号量时已经将statep的值设为0了,
         // 如果这里不是0,说明在wait之后又调用了Add方法,
         // 使用时机不对,触发panic
         if *statep != 0 {
            panic("sync: WaitGroup is reused before previous Wait has returned")
         }
         if race.Enabled {
            race.Enable()
            race.Acquire(unsafe.Pointer(wg))
         }
         return
      }
   }
}

下面的图模拟了Wait()方法的流程(同样引用自:zhuanlan.zhihu.com/p/365288361…

image.png

总结一下,WaitGroup 的原理就五个点:内存对齐,原子操作,counter,waiter,信号量

  • 内存对齐的作用是为了原子操作。

  • counter的增减使用原子操作,counter的作用是一旦为0就释放全部信号量。

  • waiter的自增使用原子操作,waiter的作用是表明要释放多少信号量。