[Golang早读] 什么是sync.errgroup

268 阅读4分钟

什么是sync.errgroup

Go团队在实验仓库中添加了一个名为sync.errgroup的新软件包。 sync.ErrGroupsync.WaitGroup功能的基础上,增加了错误传递,以及在发生不可恢复的错误时取消整个goroutine集合,或者等待超时

获取方法:

go get golang.org/x/sync

为什么要有sync.errgroup

go支持并发,一般采用的是 channel 、 sync.WaitGroup 、context,来实现各个协程之间的流程控制和消息传递。 但是对于开启的成千上万的协程,如果在每个协程内都自行去打印 错误日志的话,会造成日志分散,不好分析。 所以我们要实现一种能统一处理各个协程错误的工具

源码分析

type Group struct {
   cancel func()
   
   wg sync.WaitGroup
   
   sem chan token
   
   errOnce sync.Once
   err     error
}

Group是一个结构体,由5个部分组成:

  • cancel 一个取消的函数,主要来包装context.WithCancel的CancelFunc
  • wg 借助于sync.WaitGroup实现的
  • sem 借助channel控制最大并发的goroutine数量,类似于令牌桶的作用
  • errOnce 使用sync.Once保证只输出第一个err
  • err 记录错误的信息

errgroup提供了4个方法

func (g *Group) Wait() error {}
func (g *Group) Go(f func() error) {}
func (g *Group) TryGo(f func() error) bool {}
func (g *Group) SetLimit(n int) {}

通过函数g.Go(func() error) {}调用errgroup,启动goroutine

func (g *Group) Go(f func() error) {
   if g.sem != nil {
      g.sem <- token{}   //通过channel控制最多允许有n个goroutine执行,当达到最大值后阻塞goroutine的运行
   }

   g.wg.Add(1)   // 和sync.WaitGroup一样,每执行一个新的g,通过add方法 加1
   go func() {
      defer g.done()   // goroutine执行结束后 调用 Done方法,减1

      if err := f(); err != nil {    // 执行传入的匿名函数
         // 如果匿名函数返回错误,会记录错误信息。
         //注意这里用的 once.Do,只执行一次,仅会记录第一个出现的err
         g.errOnce.Do(func() {  
            g.err = err
            if g.cancel != nil {  // 如果初始化的有 cancel 函数,会调用 cancel退出
               g.cancel()
            }
         })
      }
   }()
}

通过调用SetLimit()方法设置errgroup中goroutine的数量最多为n

//SetLimit将该组中运行的goroutine数量限制为最多n个。负值表示不限制。
//任何对Go方法的后续调用都将阻塞,直到它可以添加一个执行的goroutine而不超过配置的限制。
//当组中的任何goroutine处于运行状态时,不能修改该限制
func (g *Group) SetLimit(n int) {
   if n < 0 {
      g.sem = nil
      return
   }
   if len(g.sem) != 0 {
      panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
   }
   g.sem = make(chan token, n)  //创建容量为n的令牌桶
}

TryGo()方法实现了安全的调用errgroup的方式

//只有当组中活动goroutine的数量当前低于配置的限制时,TryGo才会在新的goroutine中调用给定函数。
//返回值报告是否启动了goroutine
func (g *Group) TryGo(f func() error) bool {
   if g.sem != nil {
      //通过select监听channel
      select {
      case g.sem <- token{}:
         // Note: this allows barging iff channels in general allow barging.
      default:
         return false
      }
   }

   //与Go()方法实现一致
   ...
   
   return true  //成功执行
}

wait()方法就是对WaitGroup.wait()的进一步封装,在阻塞等待goroutine执行完成的同时,判断cancel()是否被执行退出

func (g *Group) Wait() error {
   g.wg.Wait()   // 和 WaitGroup 一样,在主线程调用 wait 方法,阻塞等待所有g执行完成
   if g.cancel != nil {  // 如果初始化了 cancel 函数,就执行
      g.cancel()
   }
   return g.err   // 返回第一个出现的err信息
}

简单使用

结合context来进行使用

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    group, errCtx := errgroup.WithContext(ctx)

    for index := 0; index < 3; index++ {
        indexTemp := index

        // 新建子协程
        group.Go(func() error {
            fmt.Printf("indexTemp=%d \n", indexTemp)
            if indexTemp == 0 { // 第一个协程
                fmt.Println("indexTemp == 0 start ")
                fmt.Println("indexTemp == 0 end")
            } else if indexTemp == 1 { // 第二个协程
                fmt.Println("indexTemp == 1 start")
                //这里一般都是某个协程发生异常之后,调用cancel()
                //这样别的协程就可以通过errCtx获取到err信息,以便决定是否需要取消后续操作
                cancel() // 第二个协程异常退出
                fmt.Println("indexTemp == 1 err ")
            } else if indexTemp == 2 {
                fmt.Println("indexTemp == 2 begin")

                // 休眠1秒,用于捕获子协程2的出错
                time.Sleep(1 * time.Second)

                //检查 其他协程已经发生错误,如果已经发生异常,则不再执行下面的代码
                err := CheckGoroutineErr(errCtx) // 第三个协程感知第二个协程是否正常
                if err != nil {
                    return err
                }
                fmt.Println("indexTemp == 2 end ")
            }
            return nil
        })
    }

    // 捕获err
    err := group.Wait()
    if err == nil {
        fmt.Println("都完成了")
    } else {
        fmt.Printf("get error:%v", err)
    }
}

//校验是否有协程已发生错误
func CheckGoroutineErr(errContext context.Context) error {
    select {
    case <-errContext.Done():
        return errContext.Err()
    default:
        return nil
    }
}