【Go并发编程】WaitGroup源码阅读

45 阅读3分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第 20 天,点击查看活动详情

WaitGroup

WaitGroup是Go语言中的一个并发原语,用于协调一组goroutine的执行。它主要包括三个方法:

  • Add(delta int): 计数器加上delta,可以为负数。
  • Done(): 计数器减1,相当于Add(-1)。
  • Wait(): 阻塞当前goroutine,直到计数器减为0。

WaitGroup常用于等待一组goroutine全部完成后再执行下一步操作。在使用WaitGroup时,需要注意以下几点:

  • Add方法应该在所有goroutine开始执行之前调用。
  • 在goroutine中执行Done方法以标记完成。
  • 在需要等待所有goroutine完成时,调用Wait方法来阻塞主goroutine。

当计数器为0时,Wait方法会立即返回,否则会一直阻塞,直到计数器归零。如果计数器的值变为负数,Wait方法会panic。

源码阅读

type WaitGroup struct {
   noCopy noCopy
   state1 uint64
   state2 uint32
}
  • noCopy 的辅助字段,主要就是辅助 vet 工具检查是否通过 copy 赋值这个 WaitGroup 实例。
  • state1、state2 存储了waiter 数、WaitGroup 计数值、信号量。
// 返回statep, semap指针
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
   if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
      // state1 is 64-bit aligned: nothing to do.
      return &wg.state1, &wg.state2
   } else {
      // state1 is 32-bit aligned but not 64-bit aligned: this means that
      // (&state1)+4 is 64-bit aligned.
      state := (*[3]uint32)(unsafe.Pointer(&wg.state1))
      return (*uint64)(unsafe.Pointer(&state[1])), &state[0]
   }
}
  • 根据64位和32位返回statep, semap指针
  • statep高32位代表计数值,低32位代表waiter数量
  • semap是信号量,用于唤醒waiter

这种做法的目的是为了在 32 位机器上使用 sync.WaitGroup 时,保证访问计数器的时候是原子的。因为 sync.WaitGroup 内部的计数器是 64 位的,而在 32 位机器上,64 位的访问通常不是原子的,因此需要将计数器拆分成两个 32 位的计数器,并使用信号量来同步访问这两个计数器。

func (wg *WaitGroup) Add(delta int) {
   statep, semap := wg.state()
   // 高32bit是计数值v,所以把delta左移32,增加到计数上
   state := atomic.AddUint64(statep, uint64(delta)<<32)
   // v:计数值,w:waiter数
   v := int32(state >> 32)
   w := uint32(state)
   // 计数值不能为负数
   if v < 0 {
      panic("sync: negative WaitGroup counter")
   }
   // add和wait不能并发调用
   if w != 0 && delta > 0 && v == int32(delta) {
      panic("sync: WaitGroup misuse: Add called concurrently with Wait")
   }
   // 没有等待者,正常返回
   if v > 0 || w == 0 {
      return
   }

   if *statep != state {
      panic("sync: WaitGroup misuse: Add called concurrently with Wait")
   }
   // 如果计数值v为0并且waiter的数量w不为0,那么state的值就是waiter的数量
   // 将waiter的数量设置为0,因为计数值v也是0,所以它们俩的组合*statep直接设置为0即可。此时需要并唤醒所有的waiter
   *statep = 0
   for ; w != 0; w-- {
      runtime_Semrelease(semap, false, 0)
   }
}

func (wg *WaitGroup) Done() {
   wg.Add(-1)
}

ADD方法:

  1. 计数值加delta
  2. 如果计数值<0,panic
  3. 存在与wait并发调用,panic
  4. v>0或者没有waiter,直接返回
  5. v=0,唤醒所有等待者

Done方法:

  • 相当于ADD(-1)
func (wg *WaitGroup) Wait() {
   statep, semap := wg.state()

   for {
      state := atomic.LoadUint64(statep)
      v := int32(state >> 32)
      w := uint32(state)
      // 计数值为0直接返回
      if v == 0 {
         return
      }
      // 计数不为0,等待者数加一
      if atomic.CompareAndSwapUint64(statep, state, state+1) {
         // 阻塞
         runtime_Semacquire(semap)
         // 唤醒,退出
         if *statep != 0 {
            panic("sync: WaitGroup is reused before previous Wait has returned")
         }
         return
      }
   }
}

Wait方法:

  1. 计数值=0,无需等待,直接返回
  2. waiter数量加一,阻塞
  3. 被Done唤醒,退出阻塞

总结

整体通过计数和信号量,实现waiter的阻塞与唤醒。并且代码中加强了各自异常情况的panic,一定要正常使用wg。