【Go并发编程】CyclicBarrier源码阅读

310 阅读1分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第 25 天,点击查看活动详情

CyclicBarrier

CyclicBarrier 是一种同步机制,可以让多个 goroutine 在某个点上相互等待,直到所有 goroutine 都达到该点后才继续执行。在 Go 语言中可以使用 sync.WaitGroup 实现类似的功能,但 CyclicBarrier 具有更多的特性,如可以设置等待 goroutine 数量、可以设置等待超时时间、可以在所有 goroutine 到达指定点后执行回调函数等。

WaitGroup 更适合用在“一个 goroutine 等待一组 goroutine 到达同一个执行点”的场景中,或者是不需要重用的场景中。下面是一个简单例子:

func main() {
   barrier := cyclicbarrier.New(5)
   barrier.GetNumberWaiting()
   ctx := context.Background()
   for i := 0; i < 5; i++ {
      go func(id int) {
         fmt.Printf("Goroutine %d before wait\n", id)
         barrier.Await(ctx)
         fmt.Printf("Goroutine %d after wait\n", id)
      }(i)
   }

   // Wait for a while to demonstrate that the barrier is cyclic
   time.Sleep(2 * time.Second)

   for i := 0; i < 5; i++ {
      go func(id int) {
         fmt.Printf("Goroutine %d before wait\n", id)
         barrier.Await(ctx)
         fmt.Printf("Goroutine %d after wait\n", id)
      }(i)
   }

   // Wait for a while to see the results
   time.Sleep(2 * time.Second)
}

当5个gorouting都到Await时,再一起进行下面的操作,像栅栏一样,所以也叫循环栅栏。

源码阅读

以"github.com/marusama/cyclicbarrier"为例我们一起看一下他源码是如何实现的:

type cyclicBarrier struct {
   parties       int
   barrierAction func() error

   lock  sync.RWMutex
   round *round
}
type round struct {
   count    int           // 此往返的goroutine计数
   waitCh   chan struct{} // wait channel for this roundtrip
   brokeCh  chan struct{} // channel for isBroken broadcast
   isBroken bool          // is barrier broken
}
  • parties:等待者数量
  • barrierAction:
  • lock:RW锁
  • round:包含两个chan:wait和broken,两个状态变量
func New(parties int) CyclicBarrier {
   if parties <= 0 {
      panic("parties must be positive number")
   }
   return &cyclicBarrier{
      parties: parties,
      lock:    sync.RWMutex{},
      round: &round{
         waitCh:  make(chan struct{}),
         brokeCh: make(chan struct{}),
      },
   }
}

New:

  • 检测parties是否>0
  • 创建一个cyclicBarrier

Await等待,直到所有各方都在该屏障上调用Await。

func (b *cyclicBarrier) Await(ctx context.Context) error {
   var (
      ctxDoneCh <-chan struct{}
   )
   if ctx != nil {
      ctxDoneCh = ctx.Done()
   }

   //  context is done 就直接退出
   select {
   case <-ctxDoneCh:
      return ctx.Err()
   default:
   }

   // 加锁
   b.lock.Lock()

   // 检测是否中断
   if b.round.isBroken {
      b.lock.Unlock()
      return ErrBrokenBarrier
   }

   // waiters数量+1
   b.round.count++

   // 保存局部变量以防止竞争
   waitCh := b.round.waitCh
   brokeCh := b.round.brokeCh
   count := b.round.count

   b.lock.Unlock()

   if count > b.parties {
      panic("CyclicBarrier.Await is called more than count of parties")
   }

   if count < b.parties {
      // wait other parties
      select {
      case <-waitCh:
         return nil
      case <-brokeCh:
         return ErrBrokenBarrier
      case <-ctxDoneCh:
         b.breakBarrier(true)
         return ctx.Err()
      }
   } else {
      // we are last, run the barrier action and reset the barrier
      if b.barrierAction != nil {
         err := b.barrierAction()
         if err != nil {
            b.breakBarrier(true)
            return err
         }
      }
      b.reset(true)
      return nil
   }
}

GetParties返回跨越该障碍所需的参与方数量。

func (b *cyclicBarrier) GetParties() int {
   return b.parties
}

IsBroken查询该屏障是否处于损坏状态。 如果由于ctx.Done()的中断或上次重置导致一方或多方突破该屏障,或由于错误导致屏障操作失败,则返回true;否则错误。

func (b *cyclicBarrier) IsBroken() bool {
   b.lock.RLock()
   defer b.lock.RUnlock()

   return b.round.isBroken
}

Reset将barrier重置为初始状态。

func (b *cyclicBarrier) Reset() {
   b.reset(false)
}
func (b *cyclicBarrier) reset(safe bool) {
   b.lock.Lock()
   defer b.lock.Unlock()

   if safe {
      // broadcast to pass waiting goroutines
      close(b.round.waitCh)

   } else if b.round.count > 0 {
      b.breakBarrier(false)
   }

   // create new round
   b.round = &round{
      waitCh:  make(chan struct{}),
      brokeCh: make(chan struct{}),
   }
}

GetNumberWaiting返回当前在屏障处等待的参与方的数量。

func (b *cyclicBarrier) GetNumberWaiting() int {
   b.lock.RLock()
   defer b.lock.RUnlock()

   return b.round.count
}