阅读 110

#Golang源码系列 sync.waitgroup 源码剖析

sync.waitgroup 是一个常用的用来协调等待协程结束的组件。

比如在下面的这段代码上,我们通过 waitgroup 可以让 main 等待协程退出后再退出:

func main() {

	wg := sync.WaitGroup{}

	wg.Add(1)
	go func() {
		time.Sleep(2 * time.Second)
		fmt.Println("goroutine exit.")
		wg.Done()
	}()

	wg.Wait()
	fmt.Println("main exit.")
}
复制代码

那么,waitgroup 是怎么实现的呢,我们来详解一下。

结构定义

waitgroup 的结构定义非常简单,但是涉及到了几个重要的知识点。

type WaitGroup struct {
	noCopy noCopy
	state1 [3]uint32
}
复制代码
  • noCopy 表示了可以做静态检查,不允许拷贝实例使用;
  • state1 这里面包含了 3种状态:
    • 64位值:高 32 位 计数的数量;低 32位 等待 goroutine 的数量
    • 32位值:还有 Semaphore;

根据 atomic 官方文档 最后一段中,对于 64位数在32位平台上的操作时,强制要求使用 8 字节对齐,否则就会出现问题。而如果保证 waitgroup 在32位平台上使用的话,就必须保证在任何时候,64位的操作不会出错。

所以,并不能直接在这里将变量申明成下面的样子,原因是因为我们并不能确定 counter 是不是在 8 字节对齐的位置上(即便互换了 sema 和 counter 也不行)。

type WaitGroup struct {
	noCopy  noCopy
	counter uint64
	sema    uint32
}
复制代码

这就需要有一个办法,来动态的识别当前我们操作的64位数,到底是不是在 8 字节对齐的位置上面。WaitGroup 通过申明一个 12 字节的数组,并实现了一个内部方法 state() 来保证这一点:

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	// 当数组的首地址是在一个 8 字节对齐的位置上时
	// 那么就将数组中的前 8 个字节作为64位值使用
	// 后 4 个字节作为 semaphore
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		// 如果首地址没有在 8 字节对齐的位置上时
		// 那么,就将前 4 个字节作为 semaphore
		// 后 8 个字节作为 64位计数值
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}
复制代码

增加/减少计数

WaitGroup 中,使用 Add() 增加一个计数,当需要减掉一个计数时,使用 Done() 。但实际上 Done() 调用的还是 Add(),只不过增加的是 -1:

func (wg *WaitGroup) Done() {
	wg.Add(-1)
}
复制代码

因为,增减的逻辑都放在了 Add() 中,而调用者可以随意传入正负数值到函数中,所以需要考虑两种异常情况:

  1. 计数为负数;
  2. 当有等待者等待的时候,并发调用 Add();

下面的代码详解,去掉了 race 检查的部分:

func (wg *WaitGroup) Add(delta int) {
	// 获取到计数和 semaphore
	statep, semap := wg.state()

	// 给高32位增加 delta 的计数
	state := atomic.AddUint64(statep, uint64(delta)<<32)

	// 获取到计数的值
	v := int32(state >> 32)
    
	// 获取 semaphore 的值
	w := uint32(state)

	// 计数小于 0 ,异常 panic
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
    
	// 有等待者的时候,并发调用了 Add, 异常 panic
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
    
	// 计数大于 0 正常返回
	// 没有等待者,也不需要后续操作
	if v > 0 || w == 0 {
		return
	}

	// ------- 最后的情况 计数 == 0  --------
    
	// 有等待者的时候,并发调用了 Add, 异常 panic
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
	// 重置 waiter 等于 0
	*statep = 0
	for ; w != 0; w-- {
		// 按顺序通知等待的 goroutine
		runtime_Semrelease(semap, false, 0)
	}
}
复制代码

Wait 等待计数归零

func (wg *WaitGroup) Wait() {
	// 获取状态
	statep, semap := wg.state()

	for {  // 因为有 CAS,所以要放到循环中,保证成功
		// load 状态值
		state := atomic.LoadUint64(statep)
        
		// 获取计数
		v := int32(state >> 32)
        
		// 获取等待者数量
		w := uint32(state)
        
		// 计数为0 直接返回
		if v == 0 {
			return
		}
		// 增加等待者的数量
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
			// 增加成功,等待信号量
			runtime_Semacquire(semap)
            
			// 通知计数归零了,如果状态值不为零,那么认为是有问题的,详见 Add()
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
            
			return
		}
	}
}
复制代码

结尾

整个 WaitGroup 中,实现相对其他库比较简单。但是对于 8 字节对齐的处理很有意思,值得在开发中借鉴。