我在docker compose 中看到这个用法
代码:
func (s *composeService) Ps(ctx context.Context, projectName string, options api.PsOptions) ([]api.ContainerSummary, error) {
projectName = strings.ToLower(projectName)
oneOff := *oneOffExclude*
**if options.All {
oneOff = *oneOffInclude*
**}
containers, err := s.getContainers(ctx, projectName, oneOff, options.All, options.Services...)
if err != nil {
return nil, err
}
if len(options.Services) != 0 {
containers = containers.filter(isService(options.Services...))
}
summary := make([]api.ContainerSummary, len(containers))
eg, ctx := errgroup.WithContext(ctx)
for i, container := range containers {
i, container := i, container
//
eg.Go(func() error {
// ....
return nil
})
}
return summary, eg.Wait()
}
于是发现了ErrGroup这个用法
结构体
type Group struct {
cancel func(error)
wg sync.WaitGroup
sem chan token
errOnce sync.Once
err error
}
分别方法有:
//返回一个结构体
func WithContext(ctx context.Context) (*Group, context.Context) {
ctx, cancel := withCancelCause(ctx)
return &Group{cancel: cancel}, ctx
}
重点 是 go 和try go
func (g *Group) Go(f func() error) {
//这里会被 g.sem长度所限制
if g.sem != nil {
g.sem <- token{}
}
g.wg.Add(1)
go func() {
//线程结束了会回收
defer g.done()
if err := f(); err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel(g.err)
}
})
}
}()
}
这是done 基本代码
func (g *Group) done() {
if g.sem != nil {
<-g.sem
}
g.wg.Done()
}
//try go 本质就是如果超过限制了就会有一层保护,不致于被堵塞
func (g *Group) TryGo(f func() error) bool {
if g.sem != nil {
select {
case g.sem <- token{}:
// Note: this allows barging iff channels in general allow barging.
default:
return false
}
}
g.wg.Add(1)
go func() {
defer g.done()
if err := f(); err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel(g.err)
}
})
}
}()
return true
}
也是依靠这个来去限制 sem,这个就是用于维护当前的‘线程池’
func (g *Group) SetLimit(n int) {
if n < 0 {
g.sem = nil
return
}
if len(g.sem) != 0 {
panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
}
g.sem = make(chan token, n)
}
具体案例:
package main
import (
"context"
"errors"
"fmt"
"golang.org/x/sync/errgroup"
"time"
)
// 一个模拟的长时间运行的任务
func longRunningTask(ctx context.Context, taskID int) error {
select {
case <-time.After(time.Duration(taskID) * time.Second):
// 模拟某个任务失败
if taskID == 2 {
return errors.New("task 2 failed")
}
fmt.Printf("Task %d completed\n", taskID)
return nil
case <-ctx.Done():
fmt.Printf("Task %d canceled\n", taskID)
return ctx.Err()
}
}
func main() {
// 创建一个带有上下文的 errgroup.Group
ctx := context.Background()
eg, ctx := errgroup.WithContext(ctx)
// 启动多个并发任务
for i := 1; i <= 5; i++ {
taskID := i // 避免闭包捕获问题
eg.Go(func() error {
return longRunningTask(ctx, taskID)
})
}
// 等待所有任务完成或任何一个任务失败
if err := eg.Wait(); err != nil {
fmt.Printf("Error: %v\n", err)
} else {
fmt.Println("All tasks completed successfully.")
}
}