sync.waitgroup 是一个常用的用来协调等待协程结束的组件。
比如在下面的这段代码上,我们通过 waitgroup 可以让 main 等待协程退出后再退出:
func main() {
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
time.Sleep(2 * time.Second)
fmt.Println("goroutine exit.")
wg.Done()
}()
wg.Wait()
fmt.Println("main exit.")
}
复制代码
那么,waitgroup 是怎么实现的呢,我们来详解一下。
结构定义
waitgroup 的结构定义非常简单,但是涉及到了几个重要的知识点。
type WaitGroup struct {
noCopy noCopy
state1 [3]uint32
}
复制代码
- noCopy 表示了可以做静态检查,不允许拷贝实例使用;
- state1 这里面包含了 3种状态:
- 64位值:高 32 位 计数的数量;低 32位 等待 goroutine 的数量
- 32位值:还有 Semaphore;
根据 atomic 官方文档 最后一段中,对于 64位数在32位平台上的操作时,强制要求使用 8 字节对齐,否则就会出现问题。而如果保证 waitgroup 在32位平台上使用的话,就必须保证在任何时候,64位的操作不会出错。
所以,并不能直接在这里将变量申明成下面的样子,原因是因为我们并不能确定 counter 是不是在 8 字节对齐的位置上(即便互换了 sema 和 counter 也不行)。
type WaitGroup struct {
noCopy noCopy
counter uint64
sema uint32
}
复制代码
这就需要有一个办法,来动态的识别当前我们操作的64位数,到底是不是在 8 字节对齐的位置上面。WaitGroup 通过申明一个 12 字节的数组,并实现了一个内部方法 state()
来保证这一点:
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
// 当数组的首地址是在一个 8 字节对齐的位置上时
// 那么就将数组中的前 8 个字节作为64位值使用
// 后 4 个字节作为 semaphore
if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
} else {
// 如果首地址没有在 8 字节对齐的位置上时
// 那么,就将前 4 个字节作为 semaphore
// 后 8 个字节作为 64位计数值
return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
}
}
复制代码
增加/减少计数
WaitGroup 中,使用 Add()
增加一个计数,当需要减掉一个计数时,使用 Done()
。但实际上 Done()
调用的还是 Add()
,只不过增加的是 -1:
func (wg *WaitGroup) Done() {
wg.Add(-1)
}
复制代码
因为,增减的逻辑都放在了 Add()
中,而调用者可以随意传入正负数值到函数中,所以需要考虑两种异常情况:
- 计数为负数;
- 当有等待者等待的时候,并发调用 Add();
下面的代码详解,去掉了 race 检查的部分:
func (wg *WaitGroup) Add(delta int) {
// 获取到计数和 semaphore
statep, semap := wg.state()
// 给高32位增加 delta 的计数
state := atomic.AddUint64(statep, uint64(delta)<<32)
// 获取到计数的值
v := int32(state >> 32)
// 获取 semaphore 的值
w := uint32(state)
// 计数小于 0 ,异常 panic
if v < 0 {
panic("sync: negative WaitGroup counter")
}
// 有等待者的时候,并发调用了 Add, 异常 panic
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// 计数大于 0 正常返回
// 没有等待者,也不需要后续操作
if v > 0 || w == 0 {
return
}
// ------- 最后的情况 计数 == 0 --------
// 有等待者的时候,并发调用了 Add, 异常 panic
if *statep != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// Reset waiters count to 0.
// 重置 waiter 等于 0
*statep = 0
for ; w != 0; w-- {
// 按顺序通知等待的 goroutine
runtime_Semrelease(semap, false, 0)
}
}
复制代码
Wait 等待计数归零
func (wg *WaitGroup) Wait() {
// 获取状态
statep, semap := wg.state()
for { // 因为有 CAS,所以要放到循环中,保证成功
// load 状态值
state := atomic.LoadUint64(statep)
// 获取计数
v := int32(state >> 32)
// 获取等待者数量
w := uint32(state)
// 计数为0 直接返回
if v == 0 {
return
}
// 增加等待者的数量
if atomic.CompareAndSwapUint64(statep, state, state+1) {
// 增加成功,等待信号量
runtime_Semacquire(semap)
// 通知计数归零了,如果状态值不为零,那么认为是有问题的,详见 Add()
if *statep != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
return
}
}
}
复制代码
结尾
整个 WaitGroup 中,实现相对其他库比较简单。但是对于 8 字节对齐的处理很有意思,值得在开发中借鉴。