Go并发8 并发原语 - 任务编排 - WaitGroup的基本实现原理

203 阅读3分钟
WaitGroup

WaitGroup是 package sync 用来做任务编排的一个并发原语,它要解决的就是并发 - 等待的问题:

比如现在有一个 goroutine A 在检查点(checkpoint)等待一组 goroutine 全部完成,如果在执行任务的这些 goroutine 还没全部完成,那么 goroutine A 就会阻塞在检查点,直到所有 goroutine 都完成后才能继续执行。

为什么要使用WaitGroup
假设我们有一个对账系统,流程如下:

  1. 查询订单
  2. 查询派送单
  3. 对比订单和派送单,将差异写入差异库。

3需要等1,2都完成才能开始进行,所以我们需要去轮询1,2看看他们有没有完成,这样子会导致两个问题:

  1. 性能比较低,因为两个小任务可能早就完成了,却要等很长时间才被轮询到。
  2. 会有很多无谓的轮询,空耗 CPU 资源。

在这种情况下使用WG,就可以阻塞1,2,然后即时唤醒3,从而解决上述的两个问题。 其实,很多操作系统和编程语言都提供了类似的并发原语。比如,Linux 中的 barrier、Pthread(POSIX 线程)中的 barrier、C++ 中的 std::barrier、Java 中的 CyclicBarrier 和 CountDownLatch 等。

Java中的并发 - 等待原语实战可以参考笔者的另一篇文章:
Java并发编程实战 - 并发 等待之CountDownLatch & CyclicBarrier - 掘金 (juejin.cn)

WaitGroup 的基本用法

WG的三个方法
Go 标准库中的 WaitGroup 提供了三个方法:

  1. Add,用来设置 WaitGroup 的计数值
  2. Done,用来将 WaitGroup 的计数值减 1,其实就是调用了 Add(-1);
  3. Wait,调用这个方法的 goroutine 会一直阻塞,直到 WaitGroup 的计数值变为 0。
func (wg *WaitGroup) Add(delta int)
func (wg *WaitGroup) Done()
func (wg *WaitGroup) Wait()

代码实例
下面的代码会等计数器进行10次++后输出:

package main

import (
   "fmt"
   "sync"
   "time"
)

// Counter 线程安全的计数器
type Counter struct {
   //定义Mutex紧挨着共享变量
   mu    sync.Mutex
   count uint64
}

// Incr 对计数值加一
func (c *Counter) Incr() {
   c.mu.Lock()
   c.count++
   c.mu.Unlock()
}

// Count 获取当前的计数值
func (c *Counter) Count() uint64 {
   c.mu.Lock()
   defer c.mu.Unlock()
   return c.count
}

// sleep 1秒,然后计数值加1
func worker(c *Counter, wg *sync.WaitGroup) {
   defer wg.Done() //-1
   time.Sleep(time.Second)
   c.Incr()
}
func main() {
   var counter Counter

   var wg sync.WaitGroup
   wg.Add(10)                // WaitGroup的值设置为10
   for i := 0; i < 10; i++ { // 启动10个goroutine执行加1任务
      go worker(&counter, &wg)
   }
   // 检查点,等待goroutine都完成任务
   wg.Wait()
   // 输出当前计数器的值
   fmt.Println(counter.Count())
}

WaitGroup实现

WaitGroup 的数据结构定义

  1. noCopy 的辅助字段,主要就是辅助 vet 工具检查是否通过 copy 赋值这个 WaitGroup 实例。
  2. state1,一个具有复合意义的字段,包含 WaitGroup 的计数、阻塞在检查点的 waiter 数和信号量
type WaitGroup struct {
 // 避免复制使用的一个技巧,可以告诉vet工具违反了复制使用的规则
 noCopy noCopy
 // 64bit(8bytes)的值分成两段,高32bit是计数值,低32bit是waiter的计数
 // 另外32bit是用作信号量的
 // 因为64bit值的原子操作需要64bit对齐,但是32bit编译器不支持,所以数组中的元素在不同的
 // 总之,会找到对齐的那64bit作为state,其余的32bit做信号量
 state1 [3]uint32
}

state1信息如何获取

// 得到state的地址和信号量的地址
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
   if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
      // 如果地址是64bit对齐的,数组前两个元素做state,后一个元素做信号量
      return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
   } else {
      // 如果地址是32bit对齐的,数组后两个元素用来做state,它可以用来做64bit的原子操作
      return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
   }
}

state信息在不同位数操作系统的不同表现
因为对 64 位整数的原子操作要求整数的地址是 64 位对齐的,所以针对 64 位和 32 位环 境的 state 字段的组成是不一样的。

64位
在 64 位环境下,state1 的第一个元素是 waiter 数第二个元素是 WaitGroup 的计数值,第三个元素是信号量。

image.png

32位
在 32 位环境下,如果 state1 不是 64 位对齐的地址,那么 state1 的第一个元素是信号 量,后两个元素分别是 waiter 数和计数值。

image.png

WaitGroup方法实现
Add & Done
Add方法内部通过原子操作把一个delta值加到计数值上,这个 delta 也可以是个负数,相当于为计数值减去一个值,Done 方法内部其实就是通过 Add(-1) 实现的。

func (wg *WaitGroup) Add(delta int) {
 statep, semap := wg.state()
 // 高32bit是计数值v,所以把delta左移32,增加到计数上
 state := atomic.AddUint64(statep, uint64(delta)<<32)
 v := int32(state >> 32) // 当前计数值
 w := uint32(state) // waiter count
 if v > 0 || w == 0 {
 return
 }
 // 如果计数值v为0并且waiter的数量w不为0,那么state的值就是waiter的数量
 // 将waiter的数量设置为0,因为计数值v也是0,所以它们俩的组合*statep直接设置为0即可。此
 *statep = 0
 for ; w != 0; w-- {
 runtime_Semrelease(semap, false, 0)
 }
}
// Done方法实际就是计数器减1
func (wg *WaitGroup) Done() {
 wg.Add(-1)
}

Wait
Wait会不断检查 state 的值。如果其中的计数值变为了 0,那么说明所有的任务已完成,调用者不必再等待,直接返回。如果计数值大于 0,说明此时还有任务没完成,那么调用者就变成了等待者,需要加入 waiter 队列,并且阻塞住自己

func (wg *WaitGroup) Wait() {
 statep, semap := wg.state()
 
 for {
 state := atomic.LoadUint64(statep)
 v := int32(state >> 32) // 当前计数值
 w := uint32(state) // waiter的数量
 if v == 0 {
 // 如果计数值为0, 调用这个方法的goroutine不必再等待,继续执行它后面的逻辑即可
 return
 }
 // 否则把waiter数量加1。期间可能有并发调用Wait的情况,所以最外层使用了一个for循环
 if atomic.CompareAndSwapUint64(statep, state, state+1) {
 // 阻塞休眠等待
 runtime_Semacquire(semap)
 // 被唤醒,不再阻塞,返回
 return
 }
 }
}