Golang | 详解sync.WaitGroup

1,887 阅读2分钟

# 参与掘金活动

持续创作,加速成长!这是我参与「掘金日新计划 · 6 月更文挑战」的第3天,点击查看活动详情

引言

Golang当作做并发控制的方式有很多,例如 ContextMutexChannelsync.WaitGroup

本文从 sync.WaitGroup的使用场景入手,结合源码对其做了简要分析并介绍了其特性, 最后总结了使用过程中需要注意的点

使用场景

当我们需要等待一组协程的返回之后才能进行接下来的动作时,就需要 sync.WaitGroup 来阻塞这个并发任务,等待其他 goroutine 结束。

使用waitGroup能够让我们的程序充分利用多核的特性,达到并行的效果,加快程序处理的速度

比如在我们常见的生产者消费者模型中,我们使用使用 sync.WaitGroup做协程控制,等待所有生产者和消费者协程结束后,才接着走接下来的步骤.

func main() {
   ch := make(chan int, 9)
   wg := &sync.WaitGroup{}
   wg.Add(producerNums + consumerNums)
   for i := 0; i < producerNums; i++ {
      go func(idx int) {
         defer wg.Done()
         producer(fmt.Sprintf("生产者%d ", idx), 10, ch)
      }(i)
   }
   for i := 0; i < consumerNums; i++ {
      go func(idx int) {
         defer wg.Done()
         consumer(fmt.Sprintf("消费者%d ", idx), ch)
      }(i)
   }
   wg.Wait()
   log.Println("Done")
   return
}

func producer(name string, sum int, ch chan int) {
   i := 0
   for {
      i++
      if i > sum {
         break
      }
      ch <- i
      fmt.Println("producer--", name, ":", i)
   }
}

func consumer(name string, ch chan int) {
   for {
      data, ok := <-ch
      if ok {
         fmt.Println("consumer--", name, ":", data)
      } else {
         return
      }
   }
}

源码分析

整体来说 sync.WaitGroup 的结构和源码比较简单

结构

type WaitGroup struct {
   noCopy noCopy
   
   state1 [3]uint32
}

type noCopy struct{}
func (*noCopy) Lock()   {}
func (*noCopy) Unlock() {}

  • 其中 noCopy是一个比较有趣的结构,表示 waitGroup 是一个不可复制的结构,即我们在传递waitGoup时,只能通过指针传递
  • 这样做的好处是当我们使用指针复制了原有的对象时,新旧对象能够使用同一个底层数组,共用同一个指针变量
func test1(){
   wg = &sync.WaitGroup{}
   wg.Add(1)
   wgg := wg
   wgg.Done()
   wgg.Wait()
   wg.Wait()
   log.Println("over")
}

image.png

  • state1 是一个长度为3的uint32数组,分别表示了 被ADD()/Done()方法操作的计数器正在Wait()处阻塞的协程数、以及sema信号量

方法

sync.WaitGroup 对外暴露了三个方法 — Add()Wait()Done() ,还有一个比较重要的私有方法state()用于获得sync.WaitGroup的状态和信号量

state

我们先来看看state()方法

方法解析sync.WaitGroup结构体中的state1, 返回两个指针

其中statep所指的uint64变量中高32位存储了计数器,低32位存储了此时正在Wait处阻塞的goroutine个数, semap指向用于唤醒和等待的信号量

使用了指针强转,提高了数据结构的转换效率

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
   if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
      return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
   } else {
      return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
   }
}

add

其中Done只是调用Add方法,我们来看看Add方法
方法比较简单,整体只做了几件事情

  1. 获得sync.WaitGroup的状态和信号量
  2. 是否还有任务在执行
  • 是,则直接return
  • 否,则将所有等待的goroutine唤醒

同时可以看出,waitGroup不允许counter小于0,否则程序发生panic

func Add(delta int) {
   // 从 state1 字段中取出它的状态和信号量
   statep, semap := wg.state()
  
   state := atomic.AddUint64(statep, uint64(delta)<<32)
   v := int32(state >> 32) //   获取counter
   w := uint32(state) //        获取waiter
    // counter < 0  -> panic 
   if v < 0 {
      panic("sync: negative WaitGroup counter")
   }
   if w != 0 && delta > 0 && v == int32(delta) {
      panic("sync: WaitGroup misuse: Add called concurrently with Wait")
   }
   
   // 还有任务在执行且没有在阻塞等待的goroutine
   if v > 0 || w == 0 {
      return
   }
   
   // countr = 0,即所有的任务都已执行完,将所有等待的goroutine唤醒
   for ; w != 0; w-- {
      runtime_Semrelease(semap, false, 0)
   }
}

wait

wait方法也比较简单

  1. 使用state方法获取waitGroup的状态和信号量
  2. 校验counter是否为0 ,若counter为0 直接返回
  3. counter大于0且不存在等待的goroutine时,等待的协程继续维持阻塞状态
func (wg *WaitGroup) Wait() {
   statep, semap := wg.state()
  
   for {
      state := atomic.LoadUint64(statep)
      v := int32(state >> 32)
      w := uint32(state)
      if v == 0 {
         return
      }
      if atomic.CompareAndSwapUint64(statep, state, state+1) {
         runtime_Semacquire(semap)
         if *statep != 0 {
            panic("sync: WaitGroup is reused before previous Wait has returned")
         }
         return
      }
   }
}

总结

  • sync.WaitGroup 底层使用CAS实现字段值的修改,并没有使用Mutex/RWMutex, 减少了锁竞争,并且CAS是由底层硬件提供支持,效率更高,在这个场景里有点重了
  • sync.WaitGroup 使用信号量控制协程唤醒
  • sync.WaitGroup 使用指针拷贝,新旧对象底层使用相同的内存地址
  • sync.WaitGroup 必须在Wait() 方法返回之后才能被重新使用
  • 还有一个平时可能忽略的点,如果被阻塞的协程过于多,那么会在Wait之后被同时唤醒,有点类型惊群效应