开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第 25 天,点击查看活动详情
CyclicBarrier
CyclicBarrier 是一种同步机制,可以让多个 goroutine 在某个点上相互等待,直到所有 goroutine 都达到该点后才继续执行。在 Go 语言中可以使用 sync.WaitGroup 实现类似的功能,但 CyclicBarrier 具有更多的特性,如可以设置等待 goroutine 数量、可以设置等待超时时间、可以在所有 goroutine 到达指定点后执行回调函数等。
WaitGroup 更适合用在“一个 goroutine 等待一组 goroutine 到达同一个执行点”的场景中,或者是不需要重用的场景中。下面是一个简单例子:
func main() {
barrier := cyclicbarrier.New(5)
barrier.GetNumberWaiting()
ctx := context.Background()
for i := 0; i < 5; i++ {
go func(id int) {
fmt.Printf("Goroutine %d before wait\n", id)
barrier.Await(ctx)
fmt.Printf("Goroutine %d after wait\n", id)
}(i)
}
// Wait for a while to demonstrate that the barrier is cyclic
time.Sleep(2 * time.Second)
for i := 0; i < 5; i++ {
go func(id int) {
fmt.Printf("Goroutine %d before wait\n", id)
barrier.Await(ctx)
fmt.Printf("Goroutine %d after wait\n", id)
}(i)
}
// Wait for a while to see the results
time.Sleep(2 * time.Second)
}
当5个gorouting都到Await时,再一起进行下面的操作,像栅栏一样,所以也叫循环栅栏。
源码阅读
以"github.com/marusama/cyclicbarrier"为例我们一起看一下他源码是如何实现的:
type cyclicBarrier struct {
parties int
barrierAction func() error
lock sync.RWMutex
round *round
}
type round struct {
count int // 此往返的goroutine计数
waitCh chan struct{} // wait channel for this roundtrip
brokeCh chan struct{} // channel for isBroken broadcast
isBroken bool // is barrier broken
}
- parties:等待者数量
- barrierAction:
- lock:RW锁
- round:包含两个chan:wait和broken,两个状态变量
func New(parties int) CyclicBarrier {
if parties <= 0 {
panic("parties must be positive number")
}
return &cyclicBarrier{
parties: parties,
lock: sync.RWMutex{},
round: &round{
waitCh: make(chan struct{}),
brokeCh: make(chan struct{}),
},
}
}
New:
- 检测parties是否>0
- 创建一个cyclicBarrier
Await等待,直到所有各方都在该屏障上调用Await。
func (b *cyclicBarrier) Await(ctx context.Context) error {
var (
ctxDoneCh <-chan struct{}
)
if ctx != nil {
ctxDoneCh = ctx.Done()
}
// context is done 就直接退出
select {
case <-ctxDoneCh:
return ctx.Err()
default:
}
// 加锁
b.lock.Lock()
// 检测是否中断
if b.round.isBroken {
b.lock.Unlock()
return ErrBrokenBarrier
}
// waiters数量+1
b.round.count++
// 保存局部变量以防止竞争
waitCh := b.round.waitCh
brokeCh := b.round.brokeCh
count := b.round.count
b.lock.Unlock()
if count > b.parties {
panic("CyclicBarrier.Await is called more than count of parties")
}
if count < b.parties {
// wait other parties
select {
case <-waitCh:
return nil
case <-brokeCh:
return ErrBrokenBarrier
case <-ctxDoneCh:
b.breakBarrier(true)
return ctx.Err()
}
} else {
// we are last, run the barrier action and reset the barrier
if b.barrierAction != nil {
err := b.barrierAction()
if err != nil {
b.breakBarrier(true)
return err
}
}
b.reset(true)
return nil
}
}
GetParties返回跨越该障碍所需的参与方数量。
func (b *cyclicBarrier) GetParties() int {
return b.parties
}
IsBroken查询该屏障是否处于损坏状态。 如果由于ctx.Done()的中断或上次重置导致一方或多方突破该屏障,或由于错误导致屏障操作失败,则返回true;否则错误。
func (b *cyclicBarrier) IsBroken() bool {
b.lock.RLock()
defer b.lock.RUnlock()
return b.round.isBroken
}
Reset将barrier重置为初始状态。
func (b *cyclicBarrier) Reset() {
b.reset(false)
}
func (b *cyclicBarrier) reset(safe bool) {
b.lock.Lock()
defer b.lock.Unlock()
if safe {
// broadcast to pass waiting goroutines
close(b.round.waitCh)
} else if b.round.count > 0 {
b.breakBarrier(false)
}
// create new round
b.round = &round{
waitCh: make(chan struct{}),
brokeCh: make(chan struct{}),
}
}
GetNumberWaiting返回当前在屏障处等待的参与方的数量。
func (b *cyclicBarrier) GetNumberWaiting() int {
b.lock.RLock()
defer b.lock.RUnlock()
return b.round.count
}