Go并发编程 — sync.WaitGroup

1,157 阅读2分钟

简介

WaitGroup 可以解决一个 goroutine 等待多个 goroutine 同时结束的场景,常见的场景例如启动了多个worker goruntine 进行并发处理,然后某个goruntine需要汇总信息。

使用

WaitGroup 使用比较简单,主要有下面这3个方法。

func (wg *WaitGroup) Add(delta int)
func (wg *WaitGroup) Done()
func (wg *WaitGroup) Wait()
  • Add方法:增加WaitGroup的计数值
  • Done方法:减少WaitGroup的计数值,表示goruntine完成了,内部实现就是调用了Add(-1)
  • Wait方法:需要等待的goruntine可以调用Wait进行阻塞,直到WaitGroup的计数值变为0

看一下下面这个Demo,启动了3个 goruntine 来做 worker ,注意 Add 方法需要提前设置,我这里是在for循环里面设置的,然后完成之后需要调用 Done 方法,表示 goruntine 处理结束了,最后在 main goruntine 中调用 Wait 方法来等待 3 个 worker goruntine 处理完成。

package main

import (
   "fmt"
   "sync"
   "time"
)

func main() {

   var wg sync.WaitGroup

   for i := 1; i <= 3; i++ {
      wg.Add(1)
      go func(index int) {
         defer wg.Done()
         fmt.Printf("【goruntine#%d】开始工作\n", index)
         time.Sleep(time.Second * 2)
         fmt.Printf("【goruntine#%d】结束工作\n", index)
      }(i)
   }

   wg.Wait()
   fmt.Printf("所有goruntine全部结束,可以处理数据了")
}

输出结果:

【goruntine#3】开始工作
【goruntine#1】开始工作
【goruntine#2】开始工作
【goruntine#3】结束工作
【goruntine#1】结束工作              
【goruntine#2】结束工作              
所有goruntine全部结束,可以处理数据了

实现

数据结构

来看一下 WaitGroup 的数据结构吧,主要由 noCopy 和 state1 组成。

  • noCopy:保证 vet 工具检查是否 copy 复制这个 WaitGroup 实例
  • state1:是一个复合字段,存储了 WaitGroup 的 counter(计数值),waiter数量和信号量
type WaitGroup struct {
   noCopy noCopy

   state1 [3]uint32
}

在来看一下内部的一个 state 方法,主要是获取 计数值、waiter数量和信号量的方法。有没有发现这里面还有一段判断逻辑,由于 atomic 后续需要进行 64 位的操作(拿到statep的返回值,进行原子操作),需要 64 位的内存对齐,但是在 32 位的机器上是不能保证 64 位的内存对齐的。

在 64 位环境下,state1 的第一个元素是 waiter 数,第二个元素是 WaitGroup 的计数值,第三个元素是信号量。

在 32 位环境下,如果 state1 不是 64 位对齐的地址,那么 state1 的第一个元素是信号量,后两个元素分别是 waiter 数和计数值,来保证 statep 是对齐的。

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
   if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
      return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
   } else {
      return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
   }
}

Add方法

我这里删除了部分 race 检查的代码,主要逻辑首先增加counter值,由于delta可以传负数,所以如果counter值为0,需要唤醒等待者。

func (wg *WaitGroup) Add(delta int) {
   statep, semap := wg.state() // 获取到statep和semap
   state := atomic.AddUint64(statep, uint64(delta)<<32) // 将delta左移32位,然后进行原子加操作
   v := int32(state >> 32) // 获取增加后的counter值
   w := uint32(state) // 获取waiter值

   // counter不能小于0
   if v < 0 {
      panic("sync: negative WaitGroup counter")
   }
   
   // waiter值不等0的情况下,delta > 0 && v == int32(delta)表示counter是第一次增加
   // 表示Add方法和Wait存在并发调用,也就是复用Waiter的时候需要waiter值变成0才行
   if w != 0 && delta > 0 && v == int32(delta) {
      panic("sync: WaitGroup misuse: Add called concurrently with Wait")
   }
   
   // counter大于0,或者没有等待者,直接返回
   if v > 0 || w == 0 {
      return
   }
   
   // 避免并发调用 add 和 wait
   if *statep != state {
      panic("sync: WaitGroup misuse: Add called concurrently with Wait")
   }
   
   // 走到这里的话,表示counter是0,waiter值不是0,表示所有goruntine都完成操作了,需要通知等待者了,然后将state值设置成0
   *statep = 0
   for ; w != 0; w-- {
      runtime_Semrelease(semap, false, 0)
   }
}

Wait方法

Wait的主要逻辑增加waiter值,然后进行休眠,等待被唤醒。

func (wg *WaitGroup) Wait() {
   statep, semap := wg.state()  // 获取statep和semap
   for {
      state := atomic.LoadUint64(statep) // 原子获取state
      v := int32(state >> 32) // 获取counter
      w := uint32(state) // 获取waiter值
      
      // counter为0,不需要阻塞
      if v == 0 {
         return
      }
      
      // CAS操作自增waiter值
      if atomic.CompareAndSwapUint64(statep, state, state+1) {
      
         // 阻塞休眠
         runtime_Semacquire(semap)
         
         // 被唤醒后statep不为0,代表出现异常
         if *statep != 0 {
            panic("sync: WaitGroup is reused before previous Wait has returned")
         }
         
         // 退出
         return
      }
   }
}

Done方法

Done方法的逻辑就是调用Add(-1)

func (wg *WaitGroup) Done() {
   wg.Add(-1)
}

总结

  • 调用 Add 方法建议不要传负数,直接调用 Done 方法
  • Add 方法需要在启动 goruntine 前调用,Done 方法需要在 goruntine 完成时调用
  • 调用 Done 的次数超过了 WaitGroupcounter 值,所以需要预先确定好 WaitGroup 的计数值,然后调用相同次数的 Done 完成相应的任务
  • WaitGroup 必须在 Wait 方法返回之后才能再次使用,主要是 Wait 方法和 Add 方法可能存在并发,我的建议是最好不要复用,直接创建一个新的
  • 可以同时有多个 goroutine 等待当前 WaitGroup 的 counter 值归零,这些 goroutine 会被同时唤醒