WaitGroup用于任务编排,解决了并发-等待的问题,比如:协程A等待多个协程都执行完后,再继续执行后续的流程。
package main
import "sync"
func main() {
var wg sync.WaitGroup
// 并发的打印0至9,等待打印完成后输出"done"
for i := 0; i < 10; i++ {
wg.Add(1)
go func(num int){
println(num)
wg.Done()
}(i)
}
wg.Wait()
println("done")
}
WaitGroup原理
WaitGroup源码在
$GOROOT/src/sync/waitgroup.go,为了讲解WaitGroup的原理,本文删减了WaitGroup的源码。
WaitGroup包含statep状态指针和semap信号量指针两个字段。其中,statep状态指针用于表示WaitGroup的状态,高32位为WaitGroup的计数器(当计数器的值为0时,会释放所有调用wg.Wait()的协程),低32位为等待当前WaitGroup的协程数(即调用wg.Wait()方法的协程数量);semap信号量指针用于从等待队列中唤醒调用wg.Wait()的协程。
type WaitGroup struct {
statep *uint64
semap *uint32
}
Add()方法和Done()方法用于操作WaitGroup的计数器,即statep状态指针对应的高32位。当WaitGroup的计数器清零的时候,会释放等待队列中所有的协程。
func (wg *WaitGroup) Done() {
wg.Add(-1)
}
func (wg *WaitGroup) Add(delta int) {
statep := wg.statep
semap := wg.semap
// 采用原子操作,对计数器进行加操作
state := atomic.AddUint64(statep, uint64(delta)<<32)
v := int32(state >> 32) // 计数器
w := uint32(state) // 协程等待数
// 判断是否需要释放等待队列中的协程
// 当计数器v的值大于0时 -> 还有其他协程没有执行完,不用释放等待队列
// 当等待协程数为0时 -> 等待队列为空,无需释放等待队列
if v > 0 || w == 0 {
return
}
// 将state状态清零,这里主要是将等待者协程数量清零,即statep低32位
*statep = 0
// 释放等待队列中所有的协程
for ; w != 0; w-- {
runtime_Semrelease(semap, false, 0)
}
}
Wait()方法会根据WaitGroup计数器的值将协程放到等待队列中。当计数器值为0的时候,不会阻塞当前协程直接返回;当计数器不为0的时候,会将协程放到等待队列中,同时将WaitGroup的协程等待数+1;由于存在多个协程调用Wait()方法的情况,Wait()采用了CAS更新等待数,只有CAS操作成功后才将协程放入等待队列,否则重试。
func (wg *WaitGroup) Wait() {
statep := wg.statep
semap := wg.semap
for {
// 读取状态
state := atomic.LoadUint64(statep)
v := int32(state >> 32) // 计数器
w := uint32(state) // 协程等待数
// 当计数器为0时,不用阻塞当前协程,直接返回
if v == 0 {
return
}
// 采用CAS操作增加等待者的数量,同时将当前协程放到等待队列中
if atomic.CompareAndSwapUint64(statep, state, state+1) {
runtime_Semacquire(semap)
return
}
}
}
WaitGroup源码细节
WaitGroup源码位置:
$GOROOT/src/sync/waitgroup.go。
WaitGroup在具体实现上有一些细节处理,如异常处理、32位机兼容等。
禁止Copy使用
Go语言的函数传递是值复制的,对sync包内的并发原语进行函数传递会出现不预期的情况。如下代码所示:由于值复制,run函数内部的wg和main函数内的wg不是同一个,导致main函数的协程不会等待run函数中协程的执行。
func main() {
var wg sync.WaitGroup
// 由于是值复制的,run函数内部的wg和main函数内的wg不是同一个,会出现直接打印done的情况
run(wg)
wg.Wait()
println("done")
}
func run(wg sync.WaitGroup) {
for i := 0; i < 10; i++ {
go func(num int) {
wg.Add(1)
println(num)
wg.Done()
}(i)
}
}
对此,WaitGroup采用go vet工具进行并发原语值复制的问题。上面的示例在执行完go vet后,会报出下面这样的错误。
go vet ./main.go
# command-line-arguments
./main.go:9:6: call of run copies lock value: sync.WaitGroup contains sync.noCopy
./main.go:14:13: run passes lock by value: sync.WaitGroup contains sync.noCopy
go vet会对有Lock()方法和Unlock()方法的结构体进行值复制的检查,当出现值复制就会报错。对于WaitGroup这种没有Lock()方法和Unlock()方法的结构,通过添加noCopy字段为其添加包内可见的Lock()方法和Unlock()方法,以满足go vet的使用要求。
type WaitGroup struct {
noCopy noCopy
state1 [3]uint32
}
// noCopy
type noCopy struct{}
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}
兼容32位机
WaitGroup采用64位的state状态变量,其中高32位为计数器,低32位为协程等待数量。由于32位无法保证64位数据的原子操作,WaitGroup将其拆分为两个32位变量。
type WaitGroup struct {
noCopy noCopy
// wg的state为64位,其中高32位为计数器,低32位为协程等待数量
// 由于32位机无法保证64位数据的原子操作,采用两个独立的32位变量分别表示计数器和协程等待数量
// 同时采用32位存储sema信号量,因此是长度为3的数组
state1 [3]uint32
}
// 因为涉及高位对齐还是低位对齐,因此采用统一的state()函数获取wg的状态和信号量
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
} else {
return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
}
}