深入浅出 WaitGroup 和 Errgroup

4,101 阅读8分钟

我们通常会遇到需要将一个大的任务拆分成若干个没有先后关系的子任务并发执行的场景。这么拆分的主要目的就是减少接口的响应时常,提高程序的并发度,从而优化用户体验。这就像在厨房做饭一样,我们可以在蒸米饭的同时去炒几个小菜,这样的话,等米饭蒸好了之后,菜也做好了,很快就可以享受到自己亲手做的美食了。

Golang 并发工具包中也提供了一些能够实现类似功能的并发原语,例如 WaitGroup。

使用

从 WaitGroup 的注释中我们可以知道:WaitGroup 主要用于等待一组 goroutines 执行完毕。主 goroutine 会通过调用 Add() 来设置需要等待的子 goroutine 的数量,子 goroutine 会通过调用 Done() 来说明自己已经执行完毕,主 goroutine 通过调用 Wait() 来等待所有子 goroutine 全部执行完毕,然后继续进行接下来的操作。示例如下:

func main() {
    // 开箱即用,只需要声明即可
	var wg sync.WaitGroup
	
	wg.Add(2)
	go func() {
		defer wg.Done()
		fmt.Print("task 1")
	}()
	go func() {
		defer wg.Done()
		fmt.Print("task 2")
	}()
	wg.Wait()
	
	fmt.Print("finish")
}

源码分析

上面就是 WaitGroup 的整个使用流程,整体还是比较简单的。WaitGroup 对外一共暴露了一个结构体和三个方法:

WaitGroup

type WaitGroup struct {
    // 避免复制使用的一个技巧,可以告诉vet工具违反了复制使用的规则
    // 这个在 go sync 包中有很多它出现的身影
	noCopy noCopy
	
	// 64 位:高 32 位作为计数器,低 32 位作为 waiter 计数
	// 64 位的原子操作要求 64 位对齐,但 32 位编译器无法保证这个要求
	// 因此分配 12 字节,然后将其中对齐的 8 字节作为状态,其他 4 字节用于存储原语
	state1 [3]uint32
}

WaitGroup 通过内存对齐来存储技术器、waiter 数以及信号量的值。

我们都知道 CPU 为了加快对内存的访问速度,会对内存进行对齐处理,也就是把内存当成一块一块的,每块的大小可以是 2、4、8、16 字节,CPU 读数据的时候是按照块进行读取的,块的大小被称为 memory access granularity(内存访问粒度)。在不同平台上的编译器都有自己默认的内存访问粒度,一般情况下,对于 32 位系统访问粒度是 4,64 位系统的访问粒度是 8。state1 的设计兼容了这两种平台的访问粒度,它是一个有 3 个元素的 uint32 类型(4 字节)的数组,共 12 字节,这样对于 32 位系统来说,每次访问都能访问到一个元素值,但是对于 64 位来说,访问粒度是 8,一次访问会读取到两个元素(4 x 2 = 8)。基于此事实,对于 64 位和 32 位系统,state1 数组元素分别被赋予了不同的含义:

state1[0]state1[1]state1[2]
64 位waitercountersemaphore
32 位semaphorewaitercounter
// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    // 64 位系统,内存访问粒度是 8
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
	    // 64 位,前两个元素存放 state
	    // 将两个 uint32 类型的数据赋值给一个 uint64 类型的数据
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
	    // 32 位,后两个元素存放 state
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}

Add

Add() 方法用来增加设置计数器的值,这里可以传入负数,代表计数器的值减小,但是正常情况下,我们不应该主动传入负值。

func (wg *WaitGroup) Add(delta int) {
	statep, semap := wg.state()
	// 竞态检查
	if race.Enabled {
		_ = *statep // trigger nil deref early
		if delta < 0 {
			// Synchronize decrements with Wait.
			race.ReleaseMerge(unsafe.Pointer(wg))
		}
		race.Disable()
		defer race.Enable()
	}
	// delta 左移 32 位添加到计数器上面
	state := atomic.AddUint64(statep, uint64(delta)<<32)
	// v 代表 Add() 完之后当前计数器的值,取高 32 位的值
	v := int32(state >> 32)
	// w 代表当前调用 Wait 被阻塞的数量
	w := uint32(state)
	if race.Enabled && delta > 0 && v == int32(delta) {
		// The first increment must be synchronized with Wait.
		// Need to model this as a read, because there can be
		// several concurrent wg.counter transitions from 0.
		race.Read(unsafe.Pointer(semap))
	}
	// 非法
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	// w != 0,说明已经执行了 Wait() 操作,此时不允许再执行 Add()
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	if v > 0 || w == 0 {
		return
	}
	// v == 0 && w > 0
	// 此时不能再有一些状态的并发改变的问题:
	// - Add() 和 Wait() 操作不能并发执行
	// - 如果计数器的值已经是 0 了,此时不能再执行 Wait() 操作
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// 将 waiter 计数设置为 0,并且唤醒所有 waiter
	// 由于 v 和 w 都是 0,所以这里直接将 *statep 设置为 0 就行
	*statep = 0
	// 唤醒所有 waiter
	for ; w != 0; w-- {
	    // 释放信号量
		runtime_Semrelease(semap, false, 0)
	}
}

Done

Done() 方法比较简单,内部就是简单的调用了 Add() 方法,参数传 -1,将计数器的值减 1,代表当前协程工作完毕。

func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

Wait

Wait() 方法在子 goroutine 执行完毕之前需要阻塞主 goroutine,其实现就是内部开了一个死循环,不停检查计数器的值,直到其为 0 才结束。

func (wg *WaitGroup) Wait() {
	statep, semap := wg.state()
	// 竞态检查
	if race.Enabled {
		_ = *statep // trigger nil deref early
		race.Disable()
	}
	// 启动循环
	for {
		state := atomic.LoadUint64(statep)
		v := int32(state >> 32)
		w := uint32(state)
		if v == 0 {
			// 计数器已经变成 0 了,不需要再等待,直接返回
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// 增加 waiter 数量(CAS)
		// 直接在 state 的低位加就行,也就是直接 +1
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
			if race.Enabled && w == 0 {
				// Wait must be synchronized with the first Add.
				// Need to model this is as a write to race with the read in Add.
				// As a consequence, can do the write only for the first waiter,
				// otherwise concurrent Waits will race with each other.
				race.Write(unsafe.Pointer(semap))
			}
			// 等待信号量唤醒
			runtime_Semacquire(semap)
			// 这种情况说明在上一轮 Wait() 返回之前,wg 被重新使用了(重新进行了 Add() / Wait())
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
	}
}

关于 WaitGroup 的源码解析就讲到这里了。代码很简单,加上注释一共也才 140 行左右,虽然代码行数比较少,但是其中也考虑到了很多异常使用的情况,对 WaitGroup 的使用做了很多规范和限制。接下来就让我们来看一下使用 WaitGroup 的一些注意事项。

WaitGroup 使用的一些注意事项

  1. 在任何时候都不要使计数器的值小于 0 ,这会引发程序的 panic。
  2. Add() 方法的首次调用,与对它的 Wait() 方法的调用不能同时发生,例如在两个不同的 goroutine 中分别调用这两个方法,否则也会引发 panic。因此我们在声明完 WaitGroup 的时候要尽早调用 Add() 方法。
  3. 如果想要重复使用 WaitGroup,我们需要等待前一轮调用 Wait() 返回之后再发起下一轮的调用。
  4. 调用 Done() 方法的次数要与 Add() 的计数器值相等,否则将会 panic。

Errgroup

WaitGroup 的确是一个很强大的工具,但是使用它相对来说还是有一点小麻烦,一方面我们需要自己手动调用 Add() 和 Done() 方法,一旦这两个方法有一个多调用或者少调用,最终都有可能导致程序崩溃,所以我们在使用这两个方法的时候要格外小心,确保最终计数器能够达到 0 的状态;另一方面就是它不能抛出错误给调用者,所以我们只能通过声明多个外部变量的方式(或者声明一个变量然后通过加锁来更新它的值)来分别接收每个协程的 error 才行,就像下面的代码:

func main() {
	var (
		wg sync.WaitGroup
		err1, err2 error
	)

	wg.Add(2)
	go func() {
		defer wg.Done()
		fmt.Print("task 1")
		err1 = nil
	}()
	go func() {
		defer wg.Done()
		fmt.Print("task 2")
		err2 = fmt.Errorf("task 2 error")
	}()
	wg.Wait()
	
	if err1 != nil || err2 != nil {
		// TODO
	}

	fmt.Print("finish")
}

因此,除了 WaitGroup 之外,Golang 还额外提供了另外一个更加好用的工具 -- Errgroup。

Errgroup 是 Golang 官方提供的一个同步扩展库(传送门)。它和 WaitGroup 的作用类似,但是它提供了更加丰富的功能以及更低的使用成本:

  • 和 context 集成;
  • 能够对外传播 error,可以把子任务的错误传递给 Wait 的调用者。

Errgroup 的代码非常简短,加上注释一共才 66 行,包含一个结构体以及三个对外暴露的方法,接下来就让我们走进源码,来具体看一下它是如何工作的。

Group

// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task.
//
// A zero Group is valid and does not cancel on error.
type Group struct {
	cancel func()
    // 底层还是依托 WaitGroup 实现
	wg sync.WaitGroup
    // 通过 sync.Once 来实现只接收第一个错误
	errOnce sync.Once
	// 如果子任务发生了错误,这里接收出现的第一个 error
	err     error
}

WithContext

// WithContext returns a new Group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function passed to Go
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
	ctx, cancel := context.WithCancel(ctx)
	return &Group{cancel: cancel}, ctx
}

Go

// Go calls the given function in a new goroutine.
//
// The first call to return a non-nil error cancels the group; its error will be
// returned by Wait.
func (g *Group) Go(f func() error) {
    // wg.Add(1) 计数器加 1
	g.wg.Add(1)

	go func() {
		defer g.wg.Done()

		if err := f(); err != nil {
		    // 如果有 error,则记录发生的第一个 error
			g.errOnce.Do(func() {
				g.err = err
				if g.cancel != nil {
					g.cancel()
				}
			})
		}
	}()
}

Wait

// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
    // wg.Wait() 等待所有任务执行完毕
	g.wg.Wait()
	if g.cancel != nil {
		g.cancel()
	}
	return g.err
}

使用 Errgroup,上述代码可以改为下面的样子:

func main() {
    var eg errgroup.Group
    
    eg.Go(func() error {
        fmt.Print("task 1")
		return nil
    })
    eg.Go(func() error {
        fmt.Print("task 2")
		return fmt.Errorf("task 2 error")
    })
    
    if err := eg.Wait(); err != nil {
        // TODO
    }
}

看,是不是简洁了许多!

总结

本文我们介绍了 WaitGroup 和 Errgroup 两个可以用于一对多的 goroutine 协作流程。其中 Errgroup 是对 WaitGroup 的简单封装,提供了更加简单的操作流程。当然除了 Golang 官方提供的扩展库之外,还有很多类似的其他优秀开源工具,例如 bilibili/errgroup,支持设置固定数量的协程数以及失败 cancel 机制和 panic-recover 机制等等,感兴趣的同学可以自行去了解一番。