go 源码解析 - sync.WaitGroup

299 阅读4分钟

个人主页:二郎腿 (erlangtui.top)

本文代码基于 go1.17.13,src/sync/waitgroup.go

一、简述

  • 主 goroutine 阻塞等待一组 goroutine 完成后,再进行后续操作;
  • 常用于对多个服务并发请求后,等待返回后的结果在进行后续处理;

二、基本原理

  • 通过原子操作记录当前正在执行的 goroutine 的数量和等待的 goroutine 数量,当正在执行的 goroutine 计数不为 0 时,主 goroutine 阻塞等待;当正在执行的 goroutine 计数为 0 时唤醒等待的主 goroutine,并检测该计数是否为 0,以防止在主 goroutine 阻塞等待期间WaitGroup被其他 goroutine 复用;
  • 主 goroutine 调用 Add 来设置要等待的 goroutine 数量,并发执行多个数量的子 goroutine,然后主 goroutine 调用 Wait 阻塞等待,直到所有 goroutines 完成,每个子 goroutine 在运行完成后调用 Done 让正在执行的 goroutine 数量减一,数量为 0 时阻塞等待完成;
  • WaitGroup 首次使用后不允许被复制;

三、基本用法

package main

import "sync"

func reqOther() {
}
func doSomething(w *sync.WaitGroup) {
	defer w.Done()
	reqOther()
}
func doContinue() {
}
func main() {
	num := 9
	var w sync.WaitGroup
	w.Add(num)
	for i := 1; i <= 9; i++ {
		go doSomething(&w)
	}
	w.Wait() // 阻塞等待
	doContinue()
}
  • 在主 goroutine 中要先执行Add 操作,再执行Wait操作;
  • 要在子 goroutine 中执行 Done 操作;
  • Add 的数量和执行 Done 的数量要一致;

四、源码解读

1, WaitGroup

type WaitGroup struct {
	noCopy noCopy
	state1 [3]uint32
}
  • noCopy是一个空结构体类型,用于禁止复制,会在编译期间检测,可以通过go vet检测;
  • state1将两个 32 位的值合并为一个 64 位值:高 32 位为计数器,低 32 位为 waiter 计数;
  • 64 位原子操作需要 64 位对齐,但 32 位编译器不能确保这一点;
  • 因此,分配 12 个字节,然后使用其中对齐的 8 个字节作为状态,另外 4 个字节作为 sema 信号的存储;

2, state

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		// 64 位,waiter counter sema
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		// 32 位,sema waiter counter
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}
  • state 返回指向存储在 wg.state1 中的 state 和 sema 字段的指针,计数与信号;
  • 将 waiter counter 两个计数器放进一个 uint64 变量,这样就可以在不加锁的情况下,支持并发场景下的原子操作了,极大地提高了性能

3, Add

func (wg *WaitGroup) Add(delta int) {
	statep, semap := wg.state()
	// statep 高位存储的是 counter,将 delta 左移 32 位,加到 statep 的高位上
	state := atomic.AddUint64(statep, uint64(delta)<<32)
	v := int32(state >> 32) // 右移 32 位 得到实际的 counter 值
	w := uint32(state) // 直接用 32 位截断,得到低位存储的 waiter
	if v < 0 {
		// 计数器小于 0 时,panic
		panic("sync: negative WaitGroup counter")
	}
	if w != 0 && delta > 0 && v == int32(delta) {
		// 已经调用了 Wait,但计数器为零,且 delta 为正,说明 Add 调用在 Wait 之后发生,panic
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	if v > 0 || w == 0 {
		// 计数器大于 0 或是没有调用 wait,不需要后续处理
		return
	}

	// 当 counter = 0,waiters > 0 时,现在不能同时发生状态突变:
	// - Add 不得与 Wait 同时发生,
	// - 如果 Wait 看到计数器 == 0,则不会增加 waiters。
	// 仍然做一个廉价的健全性检查来检测 WaitGroup 的滥用。
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// counter 为 0,说明所有 goroutine 已经调用了 done 操作,重置 waiter 为 0,并逐一唤醒调用 Wait 的 goroutine
	*statep = 0
	for ; w != 0; w-- {
		runtime_Semrelease(semap, false, 0)
	}
}
  • Add 将增量 delta(可能为负数)添加到 WaitGroup 计数器;
  • 如果计数器变为零,则释放所有在 Wait 上阻塞的 goroutine;如果计数器变为负数,则 Add panic;
  • 当计数器为零时,delta 为正的 Add 调用必须在 Wait 之前发生;
  • 当计数器大于零开始,负的 delta Add 调用可能随时发生;
  • 对 Add 的调用应在创建要等待的 goroutine 或其他事件的语句之前执行;
  • 如果重用 WaitGroup 来等待多个独立的事件集,则必须在返回所有以前的 Wait 调用后进行新的 Add 调用;

4, Done

// Done WaitGroup 计数减一
func (wg *WaitGroup) Done() {
	wg.Add(-1)
}
  • 调用Add函数,使 counter 计数减一;

5, Wait

// Wait 阻塞直到 WaitGroup 计数变为 0
func (wg *WaitGroup) Wait() {
	statep, semap := wg.state()
	for {
		state := atomic.LoadUint64(statep)
		v := int32(state >> 32)
		w := uint32(state)
		if v == 0 {
			// 计数器为 0,不需要等待,直接返回
			return
		}
		// 计数器不为 0,说明还有 goroutine 没有调用 Done
		// 等待者 waiters 计数加一
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
			// 计数成功后,阻塞等待
			runtime_Semacquire(semap)
			// 阻塞等待完成,其他 goroutine 均已返回,wait 结束,此时 statep 应该为 0
			if *statep != 0 {
				// 如果 statep 不为 0,说明前一次的 wait 还没有返回时,WaitGroup 被复用,直接 panic
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			return
		}
	}
}
  • 如果 counter 计数器为 0,说明所有子 goroutine 均已返回,主 goroutine 无需等待直接返回;
  • 否则,主 goroutine 阻塞等待,等阻塞等待结束后,如果计数还不为 0 ,则说明在结束之前WaitGroup已经被复用了,直接 panic