避免Go语言sync.WaitGroup的常见错误

1 阅读4分钟

在Go语言的并发编程中,sync.WaitGroup是一个非常重要的同步原语,用于等待一组goroutine完成。然而,很多开发者在使用WaitGroup时容易犯一些常见错误。本文将深入分析这些错误并提供最佳实践。

1. WaitGroup基础回顾

sync.WaitGroup用于等待一组goroutine完成。它提供了三个主要方法:

var wg sync.WaitGroup

wg.Add(1)      // 增加计数器
wg.Done()      // 减少计数器
wg.Wait()      // 阻塞等待计数器归零

2. 常见错误及解决方案

错误1:Add()调用时机不当

错误示例:

var wg sync.WaitGroup

for i := 0; i < 5; i++ {
    go func() {
        wg.Add(1)  // 错误:在goroutine内部调用Add()
        defer wg.Done()
        // 工作逻辑
    }()
}

wg.Wait()  // 可能提前返回或死锁

问题分析:

  • wg.Wait()可能在所有wg.Add(1)执行前就返回
  • 导致部分goroutine未被等待就程序结束

正确做法:

var wg sync.WaitGroup

for i := 0; i < 5; i++ {
    wg.Add(1)  // 在启动goroutine前调用Add()
    go func(id int) {
        defer wg.Done()
        fmt.Printf("Goroutine %d 完成\n", id)
    }(i)
}

wg.Wait()

错误2:Done()调用遗漏

错误示例:

var wg sync.WaitGroup

wg.Add(1)
go func() {
    // 忘记调用wg.Done()
    fmt.Println("工作完成")
}()

wg.Wait()  // 永远阻塞,死锁

解决方案: 使用defer确保Done()被调用:

var wg sync.WaitGroup

wg.Add(1)
go func() {
    defer wg.Done()  // 使用defer确保调用
    fmt.Println("工作完成")
}()

wg.Wait()

错误3:负计数器值

错误示例:

var wg sync.WaitGroup

wg.Add(-1)  // panic: sync: negative WaitGroup counter
wg.Wait()

问题分析: WaitGroup的计数器不能为负数,否则会panic。

正确做法:

var wg sync.WaitGroup

// 只能使用正数或零
wg.Add(1)
go func() {
    defer wg.Done()
    // 工作逻辑
}()

wg.Wait()

错误4:重复调用Done()

错误示例:

var wg sync.WaitGroup

wg.Add(1)
go func() {
    wg.Done()
    wg.Done()  // panic: sync: negative WaitGroup counter
}()

wg.Wait()

解决方案: 确保每个goroutine只调用一次Done():

var wg sync.WaitGroup

wg.Add(1)
go func() {
    defer wg.Done()  // 只调用一次
    // 工作逻辑
}()

wg.Wait()

错误5:WaitGroup复用问题

错误示例:

var wg sync.WaitGroup

wg.Add(1)
go func() {
    defer wg.Done()
    fmt.Println("第一轮")
}()

wg.Wait()

// 尝试复用同一个WaitGroup
wg.Add(1)  // panic: sync: WaitGroup misuse: Add called concurrently with Wait
go func() {
    defer wg.Done()
    fmt.Println("第二轮")
}()

wg.Wait()

问题分析: WaitGroup在Wait()返回后不能立即复用,需要等待所有Done()完成。

正确做法:

// 方案1:使用新的WaitGroup
var wg1 sync.WaitGroup
wg1.Add(1)
go func() {
    defer wg1.Done()
    fmt.Println("第一轮")
}()
wg1.Wait()

var wg2 sync.WaitGroup
wg2.Add(1)
go func() {
    defer wg2.Done()
    fmt.Println("第二轮")
}()
wg2.Wait()

// 方案2:使用WaitGroup池
type WaitGroupPool struct {
    wg sync.WaitGroup
}

func (p *WaitGroupPool) Run(f func()) {
    p.wg.Add(1)
    go func() {
        defer p.wg.Done()
        f()
    }()
}

func (p *WaitGroupPool) Wait() {
    p.wg.Wait()
}

3. 最佳实践

实践1:使用包装函数

创建一个包装函数来简化WaitGroup的使用:

func RunAndWait(tasks ...func()) {
    var wg sync.WaitGroup
    
    for _, task := range tasks {
        wg.Add(1)
        go func(t func()) {
            defer wg.Done()
            t()
        }(task)
    }
    
    wg.Wait()
}

// 使用示例
RunAndWait(
    func() { fmt.Println("任务1") },
    func() { fmt.Println("任务2") },
    func() { fmt.Println("任务3") },
)

实践2:结合Context使用

在需要超时控制或取消的场景中,结合Context使用:

func ProcessWithTimeout(ctx context.Context, tasks []func()) error {
    var wg sync.WaitGroup
    
    for _, task := range tasks {
        wg.Add(1)
        go func(t func()) {
            defer wg.Done()
            t()
        }(task)
    }
    
    done := make(chan struct{})
    go func() {
        wg.Wait()
        close(done)
    }()
    
    select {
    case <-done:
        return nil
    case <-ctx.Done():
        return ctx.Err()
    }
}

// 使用示例
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

err := ProcessWithTimeout(ctx, []func(){
    func() { time.Sleep(2 * time.Second); fmt.Println("任务1") },
    func() { time.Sleep(3 * time.Second); fmt.Println("任务2") },
})
if err != nil {
    fmt.Println("处理超时:", err)
}

实践3:错误收集

在并发任务中收集错误:

func RunTasks(tasks []func() error) []error {
    var wg sync.WaitGroup
    var mu sync.Mutex
    var errors []error
    
    for _, task := range tasks {
        wg.Add(1)
        go func(t func() error) {
            defer wg.Done()
            if err := t(); err != nil {
                mu.Lock()
                errors = append(errors, err)
                mu.Unlock()
            }
        }(task)
    }
    
    wg.Wait()
    return errors
}

// 使用示例
errs := RunTasks([]func() error{
    func() error { return nil },
    func() error { return fmt.Errorf("错误1") },
    func() error { return fmt.Errorf("错误2") },
})

if len(errs) > 0 {
    fmt.Println("发生错误:", errs)
}

4. 实际应用场景

场景1:批量HTTP请求

func fetchURLs(urls []string) ([]string, error) {
    var wg sync.WaitGroup
    var mu sync.Mutex
    results := make([]string, len(urls))
    errors := make([]error, len(urls))
    
    for i, url := range urls {
        wg.Add(1)
        go func(i int, url string) {
            defer wg.Done()
            resp, err := http.Get(url)
            if err != nil {
                mu.Lock()
                errors[i] = err
                mu.Unlock()
                return
            }
            defer resp.Body.Close()
            
            body, err := io.ReadAll(resp.Body)
            if err != nil {
                mu.Lock()
                errors[i] = err
                mu.Unlock()
                return
            }
            
            mu.Lock()
            results[i] = string(body)
            mu.Unlock()
        }(i, url)
    }
    
    wg.Wait()
    
    // 检查错误
    for _, err := range errors {
        if err != nil {
            return nil, err
        }
    }
    
    return results, nil
}

场景2:并发文件处理

func processFiles(files []string) error {
    var wg sync.WaitGroup
    var mu sync.Mutex
    var firstError error
    
    for _, file := range files {
        wg.Add(1)
        go func(filename string) {
            defer wg.Done()
            
            // 处理文件
            err := processFile(filename)
            if err != nil {
                mu.Lock()
                if firstError == nil {
                    firstError = err
                }
                mu.Unlock()
            }
        }(file)
    }
    
    wg.Wait()
    return firstError
}

func processFile(filename string) error {
    // 文件处理逻辑
    fmt.Printf("处理文件: %s\n", filename)
    return nil
}

5. 性能优化建议

优化1:避免过度并发

func processWithLimit(tasks []func(), maxGoroutines int) {
    var wg sync.WaitGroup
    sem := make(chan struct{}, maxGoroutines)
    
    for _, task := range tasks {
        wg.Add(1)
        sem <- struct{}{}  // 获取信号量
        
        go func(t func()) {
            defer func() {
                <-sem  // 释放信号量
                wg.Done()
            }()
            t()
        }(task)
    }
    
    wg.Wait()
}

优化2:使用工作池模式

type WorkerPool struct {
    tasks chan func()
    wg    sync.WaitGroup
}

func NewWorkerPool(numWorkers int) *WorkerPool {
    pool := &WorkerPool{
        tasks: make(chan func()),
    }
    
    for i := 0; i < numWorkers; i++ {
        pool.wg.Add(1)
        go func() {
            defer pool.wg.Done()
            for task := range pool.tasks {
                task()
            }
        }()
    }
    
    return pool
}

func (p *WorkerPool) Submit(task func()) {
    p.tasks <- task
}

func (p *WorkerPool) Shutdown() {
    close(p.tasks)
    p.wg.Wait()
}

// 使用示例
pool := NewWorkerPool(10)
for i := 0; i < 100; i++ {
    pool.Submit(func() {
        fmt.Println("处理任务")
    })
}
pool.Shutdown()

6. 总结

使用sync.WaitGroup时需要注意以下要点:

  1. Add()调用时机:在启动goroutine之前调用
  2. Done()保证调用:使用defer确保Done()被调用
  3. 避免负计数器:Add()只能使用正数
  4. 防止重复Done():每个goroutine只调用一次Done()
  5. 正确复用WaitGroup:Wait()后需要重新初始化或使用新的WaitGroup
  6. 结合Context使用:在需要超时控制的场景中
  7. 错误收集:使用互斥锁保护错误收集
  8. 控制并发数:避免创建过多goroutine

通过遵循这些最佳实践,可以编写出更加健壮和高效的并发Go程序。