Golang 标准库 tips -- waitgroup

1,136 阅读13分钟

WaitGroup 用于线程同步,很多场景下为了提高并发需要开多个协程执行,但是又需要等待多个协程的结果都返回的情况下才进行后续逻辑处理,这种情况下可以通过 WaitGroup 提供的方法阻塞主线程的执行,直到所有的 goroutine 执行完成。 本文目录结构:

WaitGroup 不能被值拷贝
Add 需要在 Wait 之前调用
使用 channel 实现 WaitGroup 的功能
Add 和 Done 数量问题
WaitGroup 和 channel 控制并发数
WaitGroup 和 channel 实现提前退出
WaitGroup 和 channel 返回错误
使用 ErrGroup 返回错误 
使用 ErrGroup 实现提前退出
改善版的 Errgroup

WaitGroup 不能被值拷贝

wg 作为一个参数传递的时候,我们在函数中操作的时候还是操作的一个拷贝的变量,对于原来的 wg 是不会改变。 这一点可以从 WaitGroup 实现的源码定义的 struct 能能看出来,WaitGroup 的 struct 就两个字段,第一个字段就是 noCopy,表明这个结构体是不希望直接被复制的。noCopy 是的实现是一个空的 struct{},主要的作用是嵌入到结构体中作为辅助 vet 工具检查是否通过 copy 赋值这个 WaitGroup 实例,如果有值拷贝的情况,会被检测出来,我们一般的 lint 工具也都能检测出来。 在某些情况下,如果 WaitGroup 需要作为参数传递到其他的方法中,一定需要使用指针类型进行传递。

type WaitGroup struct {
    noCopy noCopy

    // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
    // 64-bit atomic operations require 64-bit alignment, but 32-bit
    // compilers do not ensure it. So we allocate 12 bytes and then use
    // the aligned 8 bytes in them as state, and the other 4 as storage
    // for the sema.
    state1 [3]uint32
}

可以用以下一个例子来说明:

// 错误的用法,函数传递 wg 是值拷贝
func main() {
    wg := sync.WaitGroup{}

    wg.Add(10)

    for i := 0; i < 10; i++ {
        go func(i int) {
            do(i, wg)
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

func do(i int, wg sync.WaitGroup) { // wg 值拷贝,会导致程序
    fmt.Println(i)
    wg.Done()
}

// 正确的用法,waitgroup 参数传递使用指针的形式
func main() {
    wg := sync.WaitGroup{}

    wg.Add(10)

    for i := 0; i < 10; i++ {
        go func(i int) {
            do(i, &wg)
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

func do(i int, wg *sync.WaitGroup) {
    fmt.Println(i)
    wg.Done()
}

Add 需要在 Wait 之前调用

WaitGroup 结构体提供了三个方法,Add、Done、Wait,Add 的作用是用来设置WaitGroup的计数值(子goroutine的数量);Done的作用用来将 WaitGroup 的计数值减 1,其实就是调用Add(-1);Wait 的作用是检测 WaitGroup 计数器的值是否为 0,如果为 0 表示所有的 goroutine 都运行完成,否则会阻塞等待计数器的值为0(所有的 groutine都执行完成)之后才运行后面的代码。 所以在 WaitGroup 调用的时候一定要保障 Add 函数在 Wait 函数之前执行,否则可能会导致 Wait 方法没有等到所有的结果运行完成而被执行完。也就是我们不能在 Grountine 中来执行 Add 和 Done,这样可能当前 Grountine 来不及运行,外层的 Wait 函数检测到满足条件然后退出了。

func main() {
    wg := sync.WaitGroup{}
    wg.Wait() // 直接调用 Wait() 方法是不会阻塞的,因为 wg 中 goroutine 计数器的值为 0
    fmt.Println("success")
}
// 错误的写法,在 goroutine 中进行 Add(1) 操作。
// 可能在这些 goroutine 还没来得及 Add(1) 就已经执行 Wait 操作了
func main() {
    wg := sync.WaitGroup{}

    for i := 0; i < 10; i++ {
        go func(i int) {
            wg.Add(1)
            fmt.Println(i)
            wg.Done()
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

// 打印的结果,不是我们预期的打印 10 个元素之后再打印 success,而是会随机打印其中的一部分
success
1
0
5
2

// 正确的写法一
func main() {
    wg := sync.WaitGroup{}
    wg.Add(10) // 在 groutine 外层先把需要运行的 goroutine 的数量设置好,保障比 Wait 函数先执行

    for i := 0; i < 10; i++ {
        go func(i int) {
            fmt.Println(i)
            wg.Done()
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

// 正确的写法二
func main() {
    wg := sync.WaitGroup{}

    for i := 0; i < 10; i++ {
        wg.Add(1) // 保障比 Wait 函数先执行
        go func(i int) {
            fmt.Println(i)
            wg.Done()
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

使用 channel 实现 WaitGroup 的功能

如果想要实现主线程中等待多个协程的结果都返回的情况下才进行后续调用,也可以通过带缓存区的 channel 来实现,实现的思路是需要先知道等待 groutine 的运行的数量,然后初始化一个相同缓存区数量的 channel,在 groutine 运行结束之后往 channel 中放入一个值,并在主线程中阻塞监听获取 channel 中的值全部返回。

func main() {
    numGroutine := 10
    ch := make(chan struct{}, numGroutine)

    for i := 0; i < numGroutine; i++ {
        go func(i int) {
            fmt.Println(i)
            ch <- struct{}{}
        }(i)
    }

    for i := 0; i < numGroutine; i++ {
        <-ch
    }

    fmt.Println("success")
}

// 打印结果:
7
5
3
1
9
0
4
2
6
8
success

Add 和 Done 数量问题

需要保障 Add 的数量和 Done 的数量一致,如果 Add 数量小于 Done 数量的情况下,调用 Wait 方法会检测到计数器的值为负数,程序会报 panic;如果 Add 数量大于 Done 的数量,会导致 Wait 循环阻塞后面的代码得不到执行。 Add 数量小于 Done 数量:

func main() {
    wg := sync.WaitGroup{}
    wg.Add(1) // Add 数量小于 Done 数量

    for i := 0; i < 10; i++ {
        go func(i int) {
            fmt.Println(i)
            wg.Done()
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

// 运行结果,有两种结果
结果一:打印部分输出然后退出,这种情况是因为 Done 执行了一个只会,Wait 检测到刚好满足条件然后退出了
1
success
9
5

结果二:执行 Wait 函数的时候,计数器的值已经是负数了
0
9
3
panic: sync: negative WaitGroup counter

Add 数量大于 Done 数量:

func main() {
    wg := sync.WaitGroup{}
    wg.Add(20)

    for i := 0; i < 10; i++ {
        go func(i int) {
            fmt.Println(i)
            wg.Done()
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

// 执行结果:deadlock
0
9
3
7
8
1
4
2
6
5
fatal error: all goroutines are asleep - deadlock!

WaitGroup 和 channel 控制并发数

使用 waitgroup 可以控制一组 groutine 同时运行并等待结果返回之后再进行后续操作,虽然 groutine 对资源消耗比较小,但是大量的 groutine 并发对系统的压力还是比较大,所以这种情况如果需要控制 waitgroup 中 groutine 并发数量控制,就可以使用缓存的 channel 控制同时并发的 groutine 数量。

func main() {
    wg := sync.WaitGroup{}
    wg.Add(200)

    ch := make(chan struct{}, 10) // 控制最大并发数是 10
 
    for i := 0; i < 200; i++ {
        ch <- struct{}{}
        go func(i int) {
            fmt.Println(i)
            wg.Done()
            <-ch
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

根据使用 channel 实现 WaitGroup 的功能的思路,我们上面的代码也可以通过两个 channel 进行改造来实现。

func main() {
    numGroutine := 200 // 运行的 groutine 总数量
    numParallel := 10  // 并发的 groutine 数量

    chTotal := make(chan struct{}, numGroutine)
    chParallel := make(chan struct{}, numParallel)

    for i := 0; i < 200; i++ {
        chTotal <- struct{}{}
        go func(i int) {
            fmt.Println(i)
            <-chTotal
            chParallel <- struct{}{}
        }(i)
    }

    for i := 0; i < numGroutine; i++ {
        <-chParallel
    }
    fmt.Println("success")
}

WaitGroup 和 channel 实现提前退出

用 WaitGroup 协调一组并发 goroutine 的做法很常见,但 WaitGroup 本身也有其不足: WaitGroup 必须要等待控制的一组 goroutine 全部返回结果之后才往下运行,但是有的情况下我们希望能快速失败,也就是这一组 goroutine 中只要有一个失败了,那么就不应该等到所有 goroutine 结束再结束任务,而是提前结束以避免资源浪费,这个时候就可以使用 channel 配合 WaitGroup 实现提前退出的效果。

func main() {
    wg := sync.WaitGroup{}
    wg.Add(10)

    ch := make(chan struct{}) // 使用一个 channel 传递退出信号

    for i := 0; i < 10; i++ {
        go func(i int) {
            time.Sleep(time.Duration(i) * time.Second)
            fmt.Println(i)
            if i == 2 { // 检测到 i==2 则提前退出
                ch <- struct{}{}
            }
            wg.Done()
        }(i)
    }

    go func() {
        wg.Wait()        // wg.Wait 执行之后表示所有的 groutine 都已经执行完成了,而且没有 groutine 往 ch 传递退出信号
        ch <- struct{}{} // 需要传递一个信号,不然主线程会一直阻塞
    }()

    <-ch // 阻塞等待收到退出信号之后往下执行

    fmt.Println("success")
}

// 打印结果
0
1
2
success

WaitGroup 和 channel 返回错误

WaitGroup 除了不能快速失败之外还有一个问题就是不能在主线程中获取到 groutine 出错时返回的错误,这种情况下就可以用到 channel 进行错误传递,在主线程中获取到错误。

// 案例一:groutine 中只要有一个失败了则返回 err 并且回到主协程运行后续代码
func main() {
    wg := sync.WaitGroup{}
    wg.Add(10)

    ch := make(chan error) // 使用一个 channel 传递退出信号

    for i := 0; i < 10; i++ {
        go func(i int) {
            time.Sleep(time.Duration(i) * time.Second)
            if i == 2 { // 检测到 i==2 则提前退出
                ch <- fmt.Errorf("i can't be 2")
                close(ch)
                return
            }
            fmt.Println(i)
            wg.Done()
        }(i)
    }

    go func() {
        wg.Wait() // wg.Wait 执行之后表示所有的 groutine 都已经执行完成了,而且没有 groutine 往 ch 传递退出信号
        ch <- nil // 需要传递一个 nil error,不然主线程会一直阻塞
        close(ch)
    }()

    err := <-ch
    fmt.Println(err.Error())
}

// 运行结果:
/*
0
1
i can't be 2
*/

// 案例二:等待所有的 groutine 都运行完成再回到主线程并捕获所有的 error
func main() {
    wg := sync.WaitGroup{}
    wg.Add(10)

    ch := make(chan error, 10) // 设置和 groutine 数量一致,可以缓冲最多 10 个 error

    for i := 0; i < 10; i++ {
        go func(i int) {
            defer func() {
                wg.Done()
            }()
            time.Sleep(time.Duration(i) * time.Second)
            if i == 2 {
                ch <- fmt.Errorf("i can't be 2")
                return
            }
            if i == 3 {
                ch <- fmt.Errorf("i can't be 3")
                return
            }
            fmt.Println(i)
        }(i)
    }

    wg.Wait() // wg.Wait 执行之后表示所有的 groutine 都已经执行完成了
    close(ch) // 需要 close channel,不然主线程会阻塞

    for err := range ch {
        fmt.Println(err.Error())
    }
}

// 打印结果:
0
1
4
5
6
7
8
9
i can't be 2
i can't be 3

使用 ErrGroup 返回错误

正是由于 WaitGroup 有以上说的一些缺点,Go 团队在实验仓库(golang.org/x)增加了 errgroup.Group 的功能,相比 WaitGroup 增加了错误传递、快速失败、超时取消等功能,相对于通过 channel 和 WaitGroup 组合实现这些功能更方便,也更加推荐。 errgroup.Group 结构体也比较简单,在 sync.WaitGroup 的基础之上包装了一个 error 以及一个 cancel 方法,err 的作用是在 goroutine 出错的时候能够返回,cancel 方法的作用是在出错的时候快速失败。 errgroup.Group 对外暴露了3个方法,WithContext、Go、Wait,没有了 Add、Done 方法,其实 Add 和 Done 是在包装在了 errgroup.Group 的 Go 方法里面了,我们执行的时候不需要关心。

// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task.
//
// A zero Group is valid and does not cancel on error.
type Group struct {
    cancel func()

    wg sync.WaitGroup

    errOnce sync.Once
    err     error
}

func WithContext(ctx context.Context) (*Group, context.Context) {
    ctx, cancel := context.WithCancel(ctx)
    return &Group{cancel: cancel}, ctx
}

// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
    g.wg.Wait()
    if g.cancel != nil {
        g.cancel()
    }
    return g.err
}

// Go calls the given function in a new goroutine.
//
// The first call to return a non-nil error cancels the group; its error will be
// returned by Wait.
func (g *Group) Go(f func() error) {
    g.wg.Add(1)

    go func() {
        defer g.wg.Done()

        if err := f(); err != nil {
            g.errOnce.Do(func() {
                g.err = err
                if g.cancel != nil {
                    g.cancel()
                }
            })
        }
    }()
}

以下是使用 errgroup.Group 来实现返回 goroutine 错误的例子:

func main() {
    eg := errgroup.Group{}

    for i := 0; i < 10; i++ {
        i := i // 这里需要进行赋值操作,不然会有闭包问题,eg.Go 执行的 groutine 会引用 for 循环的 i
        eg.Go(func() error {
            if i == 2 {
                return fmt.Errorf("i can't be 2")
            }
            fmt.Println(i)
            return nil
        })
    }

    if err := eg.Wait(); err != nil {
        fmt.Println(err.Error())
    }
}

// 打印结果
9
6
7
8
3
4
1
5
0
i can't be 2

需要注意的一点是通过 errgroup.Group 来返回 err 只会返回其中一个 groutine 的错误,而且是最先返回 err 的 groutine 的错误,这一点是通过 errgroup.Group 的 errOnce 来实现的。

使用 ErrGroup 实现提前退出

使用 errgroup.Group 实现提前退出也比较简单,调用 errgroup.WithContext 方法获取 errgroup.Group 对象以及一个可以取消的 WithCancel 的 context,并且将这个 context 方法传入到所有的 groutine 中,并在 groutine 中使用 select 监听这个 context 的 Done() 事件,如果监听到了表明接收到了 cancel 信号,然后退出 groutine 即可。需要注意的是 eg.Go 一定要返回一个 err 才会触发 errgroup.Group 执行 cancel 方法。

// 案例一:通过 groutine 显示返回 err 触发 errgroup.Group 底层的 cancel 方法
func main() {
    ctx := context.Background()
    eg, ctx := errgroup.WithContext(ctx)

    for i := 0; i < 10; i++ {
        i := i // 这里需要进行赋值操作,不然会有闭包问题,eg.Go 执行的 groutine 会引用 for 循环的 i
        eg.Go(func() error {
            select {
            case <-ctx.Done():
                return ctx.Err()
            case <-time.After(time.Duration(i) * time.Second):
            }
            if i == 2 {
                return fmt.Errorf("i can't be 2") // 需要返回 err 才会导致 eg 的 cancel 方法
            }
            fmt.Println(i)
            return nil
        })
    }

    if err := eg.Wait(); err != nil {
        fmt.Println(err.Error())
    }
}

// 打印结果:
0
1
i can’t be 2

// 案例二:通过显示调用 cancel 方法通知到各个 groutine 退出
func main() {
    ctx, cancel := context.WithCancel(context.Background())
    eg, ctx := errgroup.WithContext(ctx)

    for i := 0; i < 10; i++ {
        i := i // 这里需要进行赋值操作,不然会有闭包问题,eg.Go 执行的 groutine 会引用 for 循环的 i
        eg.Go(func() error {
            select {
            case <-ctx.Done():
                return ctx.Err()
            case <-time.After(time.Duration(i) * time.Second):
            }
            if i == 2 {
                cancel()
                return nil // 可以不用返回 err,因为手动触发了 cancel 方法
                //return fmt.Errorf("i can't be 2")
            }
            fmt.Println(i)
            return nil
        })
    }

    if err := eg.Wait(); err != nil {
        fmt.Println(err.Error())
    }
}

// 打印结果:
0
1
context canceled


// 案例三:
// 基于 errgroup 实现一个 http server 的启动和关闭 ,以及 linux signal 信号的注册和处理,要保证能够 一个退出,全部注销退出
// https://lailin.xyz/post/go-training-week3-errgroup.html
func main() {
    g, ctx := errgroup.WithContext(context.Background())

    mux := http.NewServeMux()
    mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
        w.Write([]byte("pong"))
    })

    // 模拟单个服务错误退出
    serverOut := make(chan struct{})
    mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) {
        serverOut <- struct{}{}
    })

    server := http.Server{
        Handler: mux,
        Addr:    ":8080",
    }

    // g1
    // g1 退出了所有的协程都能退出么?
    // g1 退出后, context 将不再阻塞,g2, g3 都会随之退出
    // 然后 main 函数中的 g.Wait() 退出,所有协程都会退出
    g.Go(func() error {
        return server.ListenAndServe()
    })

    // g2
    // g2 退出了所有的协程都能退出么?
    // g2 退出时,调用了 shutdown,g1 会退出
    // g2 退出后, context 将不再阻塞,g3 会随之退出
    // 然后 main 函数中的 g.Wait() 退出,所有协程都会退出
    g.Go(func() error {
        select {
        case <-ctx.Done():
            log.Println("errgroup exit...")
        case <-serverOut:
            log.Println("server will out...")
        }

        timeoutCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
        // 这里不是必须的,但是如果使用 _ 的话静态扫描工具会报错,加上也无伤大雅
        defer cancel()

        log.Println("shutting down server...")
        return server.Shutdown(timeoutCtx)
    })

    // g3
    // g3 捕获到 os 退出信号将会退出
    // g3 退出了所有的协程都能退出么?
    // g3 退出后, context 将不再阻塞,g2 会随之退出
    // g2 退出时,调用了 shutdown,g1 会退出
    // 然后 main 函数中的 g.Wait() 退出,所有协程都会退出
    g.Go(func() error {
        quit := make(chan os.Signal, 0)
        signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)

        select {
        case <-ctx.Done():
            return ctx.Err()
        case sig := <-quit:
            return errors.Errorf("get os signal: %v", sig)
        }
    })

    fmt.Printf("errgroup exiting: %+v\n", g.Wait())
}

改善版的 Errgroup

使用 errgroup.Group 的 WithContext 我们注意到在返回 eg 对象的同时还会返回另外一个可以取消的 context 对象,这个 context 对象的功能就是用来传递到 eg 需要同步的 groutine 中有一个发生错误时取消整个同步的 groutine,但是有不少同学可能会不经意将这个 context 传到其他的非 eg 同步的业务代码groutine 中,这样会导致非关联的业务代码莫名其妙的收到 cancel 信息,类似如下的写法:

func main() {
    ctx := context.Background()
    eg, ctx := errgroup.WithContext(ctx)

    for i := 0; i < 10; i++ {
        i := i // 这里需要进行赋值操作,不然会有闭包问题,eg.Go 执行的 groutine 会引用 for 循环的 i
        eg.Go(func() error {
            select {
            case <-ctx.Done():
                return ctx.Err()
            case <-time.After(time.Duration(i) * time.Second):
            }
            if i == 2 {
                return fmt.Errorf("i can't be 2") // 需要返回 err 才会导致 eg 的 cancel 方法
            }
            fmt.Println(i)
            return nil
        })
    }

    if err := eg.Wait(); err != nil {
        fmt.Println(err.Error())
    }

    OtherLogic(ctx)
}

func OtherLogic(ctx context.Context) {
    // 这里的 context 用了创建 eg 返回的 context,这个 context 可能会往后面更多的 func 中传递
    // 如果在该方法或者后面的 func 中有对 context 监听取消型号,会导致这些 context 被取消了
}

另外不管是 WaitGroup 还是 errgroup.Group 都不支持控制最大并发限制以及 panic 恢复的功能,因为我们不能保障我们通过创建的 groutine 不会出现异常,如果没有在创建的协程中捕获异常,会直接导致整个程序退出,这是非常危险的。 这里推荐一下 bilbil 开源的微服务框架 go-kratos/kratos 自己实现了一个改善版本的 errgroup.Group,其实现的的思路是利用 channel 来控制并发,并且创建 errgroup 的时候不会返回 context 避免 context 往非关联的业务方法中传递。