Golang编程之WaitGroup

112 阅读5分钟

WaitGroup,我们以前都多多少少学习过,或者是使用过。其实,WaitGroup 很简单,就是package sync用来做任务编排的一个并发原语。它要解决的就是并发-等待的问题: 现在有个goroutine A在检查点 (checkpoint)等待一组 goroutine 全部完成,如果在执行任务的这些 goroutine 还没全部完成,那么 goroutine A 就会阻塞在检查点,直到所有goroutine 都完成后才能继续执行。

比如,我们要完成一个大的任务,需要使用并行的 goroutine 执行三个小任务,只有这三个小任务都完成,我们才能去执行后面的任务。如果通过轮询的方式定时询问三个小任务是否完成,会存在两个问题:一是,性能比较低,因为三个小任务可能早就完成了,却要等很长时间才被轮询到;二是,会有很多无谓的轮询,空耗 CPU 资源 那么,这个时候使用 WaitGroup 并发原语就比较有效了,它可以阻塞等待的 goroutine。等到三个小任务都完成了,再即时唤醒它们。

WaitGroup基本用法

func (wg *WaitGroup) Add(delta int)
func (wg *WaitGroup) Done()
func (wg *WaitGroup) Wait()

我们分别看下这三个方法

  • Add,用来设置WaitGroup 的计数值
  • Done,用来将 WaitGroup 的计数值减 1,其实就是调用了 Add(-1);
  • Wait,调用这个方法的 goroutine 会一直阻塞,直到 WaitGroup 的计数值变为0。

接下来,我们通过一个使用 WaitGroup 的例子,来看下Add、Done、Wait 方法的基本用法。 在这个例子中,我们使用了以前实现的计数器 struct。我们启动了10个 worker,分别对计数值加一,10个 worker 都完成后,我们期望输出计数器的值。

type Counter struct {
	mu sync.Mutex
	count uint64
}

// 对计数值加一
func (c *Counter) Incr() {
	c.mu.Lock()
	c.count++
	c.mu.Unlock()
}

// 获取当前的计数值
func (c *Counter) Count() uint64 {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.count
}

// sleep 1秒,然后计数值加1
func worker(c *Counter, wg *sync.WaitGroup) {
	defer wg.Done()
	time.Sleep(time.Second)
	c.Incr()
}

func main() {
	var counter Counter
	var wg sync.WaitGroup
	
	wg.Add(10) // WaitGroup的值设置为10
	for i := 0; i < 10; i++ { // 启动10个goroutine执行加1任务
		go worker(&counter, &wg)
	}
	
	// 检查点,等待goroutine都完成任务
	wg.Wait()
	
	// 输出当前计数器的值
	fmt.Println(counter.Count())
}

WaitGroup的实现

首先,我们看看 WaitGroup 的数据结构。它包括了一个 noCopy 的辅助字段,一个 state1记录 WaitGroup 状态的数组。

  • noCopy 的辅助字段,主要就是辅助 vet 工具检查是否通过 copy 赋值这个 WaitGroup 实例。我会在后面和你详细分析这个字段:
  • state1,一个具有复合意义的字段,包含 WaitGroup 的计数、阻塞在检查点的 waiter 数和信号量。

WaitGroup的数据结构定义以及 state 信息的获取方法如下

type WaitGroup struct {
	// 避免复制使用的一个技巧,可以告诉vet工具违反了复制使用的规则
	noCopy noCopy
	
	// 64bit(8bytes)的值分成两段,高32bit是计数值,低32bit是waiter的计数
	// 另外32bit是用作信号量的
	// 因为64bit值的原子操作需要64bit对齐,但是32bit编译器不支持,所以数组中的元素在不同的
	// 总之,会找到对齐的那64bit作为state,其余的32bit做信号量
	state1 [3]uint32
}

// 得到state的地址和信号量的地址
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		// 如果地址是64bit对齐的,数组前两个元素做state,后一个元素做信号量
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		// 如果地址是32bit对齐的,数组后两个元素用来做state,它可以用来做64bit的原子操作
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}

因为对64位整数的原子操作要求整数的地址是64位对齐的,所以针对64位和32位环境的state 字段的组成是不一样的。 在64位环境下,state1的第一个元素是 waiter 数,第二个元素是 WaitGroup 的计数值第三个元素是信号量。

在32位环境下,如果state1 不是64 位对齐的地址,那么state1 的第一个元素是信号量后两个元素分别是 waiter 数和计数值。

然后,我们继续深入源码,看一下 Add、Done 和Wait 这三个方法的实现。

在查看这部分源码实现时,我们会发现,除了这些方法本身的实现外,还会有一些额外的代码,主要是 race 检查和异常检查的代码。其中,有几个检查非常关键,如果检查不通过,会出现 panic,这部分内容我会在下一小节分析 WaitGroup 的错误使用场景时介绍。现在,我们先专注在Add、Wait 和 Done 本身的实现代码上

我先为你梳理下Add 方法的逻辑。Add 方法主要操作的是 state 的计数部分。你可以为计数值增加一个 delta 值,内部通过原子操作把这个值加到计数值上。需要注意的是,这个 delta也可以是个负数,相当于为计数值减去一个值,Done 方法内部其实就是通过 Add(-1)实现的。

它的实现代码如下

func (wg *WaitGroup) Add(delta int) {
statep, semap := wg.state()
// 高32bit是计数值v,所以把delta左移32,增加到计数上
state := atomic.AddUint64(statep, uint64(delta)<<32)
v := int32(state >> 32) // 当前计数值
w := uint32(state) // waiter count
if v > 0 || w == 0 {
return
}
// 如果计数值v为0并且waiter的数量w不为0,那么state的值就是waiter的数量
// 将waiter的数量设置为0,因为计数值v也是0,所以它们俩的组合*statep直接设置为0即可。此时需
*statep = 0
for ; w != 0; w-- {
runtime_Semrelease(semap, false, 0)
}
}
// Done方法实际就是计数器减1
func (wg *WaitGroup) Done() {
wg.Add(-1)
}

Wait 方法的实现逻辑是:不断检查 state 的值。如果其中的计数值变为了0,那么说明所有的任务已完成,调用者不必再等待,直接返回。如果计数值大于 0,说明此时还有任务没完成,那么调用者就变成了等待者,需要加入 waiter 队列,并且阻塞住自己

其主实现代码如下

func (wg *WaitGroup) Wait() {
	statep, semap := wg.state()
	for {
		state := atomic.LoadUint64(statep)
		v := int32(state >> 32) // 当前计数值
		w := uint32(state) // waiter的数量
		
		if v == 0 {
			// 如果计数值为0, 调用这个方法的goroutine不必再等待,继续执行它后面的逻辑即可
			return
		}
		
		// 否则把waiter数量加1。期间可能有并发调用Wait的情况,所以最外层使用了一个for循环
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
			// 阻塞休眠等待
			runtime_Semacquire(semap)
			// 被唤醒,不再阻塞,返回
			return
		}
	}
}

使用WaitGroup 时的常见错误

** 常见问题一:计数器设置为负值 **

WaitGroup 的计数器的值必须大于等于 0。我们在更改这个计数值的时候,WaitGroup 会先做检查,如果计数值被设置为负数,就会导致 panic。

般情况下,有两种方法会导致计数器设置为负数

第一种方法是:调用Add 的时候传递一个负数。如果你能保证当前的计数器加上这个负数后还是大于等于0的话,也没有问题,否则就会导致 panic。

第二个方法是: 调用 Done 方法的次数过多,超过了 WaitGroup 的计数值

使用 WaitGroup 的正确姿势是,预先确定好 WaitGroup 的计数值,然后调用相同次数的Done 完成相应的任务。比如,在 WaitGroup 变量声明之后,就立即设置它的计数值,或者在goroutine 启动之前增加 1,然后在 goroutine 中调用 Done。

常见问题二:不期望的 Add 时机

等所有的Add 方法调用之后再调用在使用 WaitGroup 的时候,你一定要遵循的原则就是,Wait,否则就可能导致 panic 或者不期望的结果

常见问题三:前一个 Wait 还没结束就重用 WaitGroup

前一个 Wait 还没结束就重用 WaitGroup”这一点似乎不太好理解,我借用田径比赛的例子和你解释下吧。在田径比赛的百米小组赛中,需要把选手分成几组,一组选手比赛完之后,就可以进行下一组了。为了确保两组比赛时间上没有冲突,我们在模型化这个场景的时候,可以使用WaitGroup。

WaitGroup 等一组比赛的所有选手都跑完后 5 分钟,才开始下一组比赛。下一组比赛还可以使用这个 WaitGroup 来控制,因为 WaitGroup 是可以重用的。只要 WaitGroup 的计数值恢复到零值的状态,那么它就可以被看作是新创建的 WaitGroup,被重复使用。

但是,如果我们在 WaitGroup 的计数值还没有恢复到零值的时候就重用,就会导致程序panic.

WaitGroup 虽然可以重用,但是是有一个前提的,那就是必须等到上一轮的 Wait完成之后,才能重用 WaitGroup 执行下一轮的 Add/Wait,如果你在 Wait 还没执行完的时候就调用下一轮 Add 方法,就有可能出现 panic。