[Go并发] - WaitGroup源码解析

470 阅读2分钟

WaitGroup用于任务编排,解决了并发-等待的问题,比如:协程A等待多个协程都执行完后,再继续执行后续的流程。

package main

import "sync"

func main() {
    var wg sync.WaitGroup
    // 并发的打印0至9,等待打印完成后输出"done"
    for i := 0; i < 10; i++ {
            wg.Add(1)
            go func(num int){
            println(num)
            wg.Done()
        }(i)
    }
    wg.Wait()
    println("done")
}

WaitGroup原理

WaitGroup源码在$GOROOT/src/sync/waitgroup.go,为了讲解WaitGroup的原理,本文删减了WaitGroup的源码。

WaitGroup包含statep状态指针和semap信号量指针两个字段。其中,statep状态指针用于表示WaitGroup的状态,高32位为WaitGroup的计数器(当计数器的值为0时,会释放所有调用wg.Wait()的协程),低32位为等待当前WaitGroup的协程数(即调用wg.Wait()方法的协程数量);semap信号量指针用于从等待队列中唤醒调用wg.Wait()的协程。

type WaitGroup struct {
    statep *uint64
    semap *uint32
}

Add()方法和Done()方法用于操作WaitGroup的计数器,即statep状态指针对应的高32位。当WaitGroup的计数器清零的时候,会释放等待队列中所有的协程。

func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

func (wg *WaitGroup) Add(delta int) {

    statep := wg.statep
    semap  := wg.semap

    // 采用原子操作,对计数器进行加操作
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    v := int32(state >> 32)   // 计数器
    w := uint32(state)	      // 协程等待数

    // 判断是否需要释放等待队列中的协程
    // 当计数器v的值大于0时 -> 还有其他协程没有执行完,不用释放等待队列
    // 当等待协程数为0时 -> 等待队列为空,无需释放等待队列
    if v > 0 || w == 0 {
        return
    }

    // 将state状态清零,这里主要是将等待者协程数量清零,即statep低32位
    *statep = 0
    // 释放等待队列中所有的协程
    for ; w != 0; w-- {
        runtime_Semrelease(semap, false, 0)
    }
}

Wait()方法会根据WaitGroup计数器的值将协程放到等待队列中。当计数器值为0的时候,不会阻塞当前协程直接返回;当计数器不为0的时候,会将协程放到等待队列中,同时将WaitGroup的协程等待数+1;由于存在多个协程调用Wait()方法的情况,Wait()采用了CAS更新等待数,只有CAS操作成功后才将协程放入等待队列,否则重试。

func (wg *WaitGroup) Wait() {

    statep := wg.statep
    semap := wg.semap

    for {
        // 读取状态
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32)		// 计数器
        w := uint32(state)		// 协程等待数

        // 当计数器为0时,不用阻塞当前协程,直接返回
        if v == 0 {
            return
        }

        // 采用CAS操作增加等待者的数量,同时将当前协程放到等待队列中
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            runtime_Semacquire(semap)
            return
        }
    }
}

WaitGroup源码细节

WaitGroup源码位置:$GOROOT/src/sync/waitgroup.go

WaitGroup在具体实现上有一些细节处理,如异常处理、32位机兼容等。

禁止Copy使用

Go语言的函数传递是值复制的,对sync包内的并发原语进行函数传递会出现不预期的情况。如下代码所示:由于值复制,run函数内部的wg和main函数内的wg不是同一个,导致main函数的协程不会等待run函数中协程的执行。

func main() {
    var wg sync.WaitGroup
    // 由于是值复制的,run函数内部的wg和main函数内的wg不是同一个,会出现直接打印done的情况
    run(wg)
    wg.Wait()
    println("done")
}

func run(wg sync.WaitGroup) {
    for i := 0; i < 10; i++ {
        go func(num int) {
            wg.Add(1)
            println(num)
            wg.Done()
        }(i)
    }
}

对此,WaitGroup采用go vet工具进行并发原语值复制的问题。上面的示例在执行完go vet后,会报出下面这样的错误。

go vet ./main.go

# command-line-arguments
./main.go:9:6: call of run copies lock value: sync.WaitGroup contains sync.noCopy
./main.go:14:13: run passes lock by value: sync.WaitGroup contains sync.noCopy

go vet会对有Lock()方法和Unlock()方法的结构体进行值复制的检查,当出现值复制就会报错。对于WaitGroup这种没有Lock()方法和Unlock()方法的结构,通过添加noCopy字段为其添加包内可见的Lock()方法和Unlock()方法,以满足go vet的使用要求。

type WaitGroup struct {
    noCopy noCopy
    state1 [3]uint32
}

// noCopy
type noCopy struct{}
func (*noCopy) Lock()   {}
func (*noCopy) Unlock() {}

兼容32位机

WaitGroup采用64位的state状态变量,其中高32位为计数器,低32位为协程等待数量。由于32位无法保证64位数据的原子操作,WaitGroup将其拆分为两个32位变量。

type WaitGroup struct {
    noCopy noCopy
    // wg的state为64位,其中高32位为计数器,低32位为协程等待数量
    // 由于32位机无法保证64位数据的原子操作,采用两个独立的32位变量分别表示计数器和协程等待数量
    // 同时采用32位存储sema信号量,因此是长度为3的数组
    state1 [3]uint32
}

// 因为涉及高位对齐还是低位对齐,因此采用统一的state()函数获取wg的状态和信号量
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
    } else {
        return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
    }
}