golang Errgroup 用法以及实现的场景

75 阅读1分钟

我在docker compose 中看到这个用法

代码:


func (s *composeService) Ps(ctx context.Context, projectName string, options api.PsOptions) ([]api.ContainerSummary, error) {

projectName = strings.ToLower(projectName)

oneOff := *oneOffExclude*

**if options.All {

oneOff = *oneOffInclude*

**}

containers, err := s.getContainers(ctx, projectName, oneOff, options.All, options.Services...)

if err != nil {

return nil, err

}


if len(options.Services) != 0 {

containers = containers.filter(isService(options.Services...))

}

summary := make([]api.ContainerSummary, len(containers))

eg, ctx := errgroup.WithContext(ctx)

for i, container := range containers {

i, container := i, container
//

eg.Go(func() error {

// ....

return nil

})

}

return summary, eg.Wait()

}

于是发现了ErrGroup这个用法

结构体

type Group struct {
    cancel func(error)

    wg sync.WaitGroup

    sem chan token
    errOnce sync.Once
    err     error
}

分别方法有:

//返回一个结构体
func WithContext(ctx context.Context) (*Group, context.Context) {
  
    ctx, cancel := withCancelCause(ctx)
    return &Group{cancel: cancel}, ctx
}

重点 是 go 和try go

func (g *Group) Go(f func() error) {
    //这里会被 g.sem长度所限制
    if g.sem != nil {
       g.sem <- token{}
    }

    g.wg.Add(1)
    go func() {
       //线程结束了会回收
       defer g.done()

       if err := f(); err != nil {
          g.errOnce.Do(func() {
             g.err = err
             if g.cancel != nil {
                g.cancel(g.err)
             }
          })
       }
    }()
}

这是done 基本代码

func (g *Group) done() {
    if g.sem != nil {
       <-g.sem
    }
    g.wg.Done()
}
//try go 本质就是如果超过限制了就会有一层保护,不致于被堵塞
func (g *Group) TryGo(f func() error) bool {
    if g.sem != nil {
       select {
       case g.sem <- token{}:
          // Note: this allows barging iff channels in general allow barging.
       default:
          return false
       }
    }

    g.wg.Add(1)
    go func() {
       defer g.done()

       if err := f(); err != nil {
          g.errOnce.Do(func() {
             g.err = err
             if g.cancel != nil {
                g.cancel(g.err)
             }
          })
       }
    }()
    return true
}

也是依靠这个来去限制 sem,这个就是用于维护当前的‘线程池’

func (g *Group) SetLimit(n int) {
    if n < 0 {
       g.sem = nil
       return
    }
    if len(g.sem) != 0 {
       panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
    }
    g.sem = make(chan token, n)
}

具体案例:

package main

  


import (

    "context"

    "errors"

    "fmt"

    "golang.org/x/sync/errgroup"

    "time"

)

  


// 一个模拟的长时间运行的任务

func longRunningTask(ctx context.Context, taskID int) error {

    select {

    case <-time.After(time.Duration(taskID) * time.Second):

        // 模拟某个任务失败

        if taskID == 2 {

            return errors.New("task 2 failed")

        }

        fmt.Printf("Task %d completed\n", taskID)

        return nil

    case <-ctx.Done():

        fmt.Printf("Task %d canceled\n", taskID)

        return ctx.Err()

    }

}

  


func main() {

    // 创建一个带有上下文的 errgroup.Group

    ctx := context.Background()

    eg, ctx := errgroup.WithContext(ctx)

  


    // 启动多个并发任务

    for i := 1; i <= 5; i++ {

        taskID := i // 避免闭包捕获问题

        eg.Go(func() error {

            return longRunningTask(ctx, taskID)

        })

    }

  


    // 等待所有任务完成或任何一个任务失败

    if err := eg.Wait(); err != nil {

        fmt.Printf("Error: %v\n", err)

    } else {

        fmt.Println("All tasks completed successfully.")

    }

}