我们知道线程池的实现原理是非常简单的,主要包括3个部分
- 线程池的创建与销毁
- 线程池中work的管理
- 任务的提交与调度
从以上的图中可以看出workpool的大概实现原理,capacity代表着线程池的最大容量,我们使用一个待缓冲的channel active作为计数器,当active channel 可写的时候,我们就创建一个worker,用户使用schedule提交任务,当,当active channel不可写的时候,pool停止创建pool,当某个运行中worker退出,释放active,pool才可以继续创建worker。我们把自身的任务抽象成一个个task,通过schedule调度进入task channel,已经创建的worker从task channel 读取task 并执行。workpool的工作原理是不是很简单啊 Talk is Cheap,Show me the Code,接下来我们就来实现一个简易版的workpool,来证明上述分析的可行性。
package workpool
import (
"errors"
"fmt"
"sync"
)
const (
defaultCapacity = 10
maxCapacity = 100
)
type Task func()
type Pool struct {
capacity int
active chan struct{}
tasks chan Task
quit chan struct{} // pool 销毁前的退出信号
wg sync.WaitGroup // pool销毁时等待所有worker退出
}
func NewPool(capacity int) *Pool {
if capacity <= 0 {
capacity = defaultCapacity
}
if capacity > maxCapacity {
capacity = maxCapacity
}
fmt.Println("workpool start")
p := &Pool{
capacity: capacity,
active: make(chan struct{}, capacity),
tasks: make(chan Task),
quit: make(chan struct{}),
}
go p.run()
return p
}
func (p *Pool) run() {
idx := 0
for {
select {
// recive a workpool free signal quit
case <-p.quit:
fmt.Println("run method exit")
return
// if active channel is not fulled create a new worker
case p.active <- struct{}{}:
idx++
go p.createWorker(idx)
}
}
}
func (p *Pool) createWorker(idx int) {
p.wg.Add(1)
go func() {
defer func() {
if err := recover(); err != nil {
fmt.Printf("worker[%03d]: recover panic[%s] and exit\n", idx, err)
<-p.active
}
p.wg.Done()
}()
fmt.Printf("worker[%03d]: start\n", idx)
for {
select {
case <-p.quit:
fmt.Printf("worker[%03d]: exit\n", idx)
<-p.active
return
case t := <-p.tasks:
fmt.Printf("worker[%03d]: receive a task\n", idx)
t()
}
}
}()
}
var ErrWorkerPoolFreed = errors.New("workerpool freed")
func (p *Pool) Schedule(t Task) error {
select {
// check pool is freed before schedule task
case <-p.quit:
return ErrWorkerPoolFreed
case p.tasks <- t:
return nil
}
}
func (p *Pool) Free() {
close(p.quit)
p.wg.Wait()
fmt.Printf("workerpool freed\n")
}
测试代码
package workpool
import (
"testing"
"time"
)
func TestWorkpool(t *testing.T) {
pool := NewPool(5)
for i := 0; i < 10; i++ {
err := pool.Schedule(func() {
time.Sleep(time.Second * 3)
})
if err != nil {
println("task: ", i, "err:", err)
}
}
pool.Free()
}
测试运行结果
=== RUN TestWorkpool
workpool start
worker[005]: start
worker[005]: receive a task
worker[001]: start
worker[001]: receive a task
worker[004]: start
worker[002]: start
worker[002]: receive a task
worker[004]: receive a task
worker[003]: start
worker[003]: receive a task
worker[004]: receive a task
worker[005]: receive a task
worker[003]: receive a task
run method exit
worker[002]: receive a task
worker[001]: receive a task
worker[001]: exit
worker[002]: exit
worker[005]: exit
worker[003]: exit
worker[004]: exit
workerpool freed
--- PASS: TestWorkpool (6.00s)
PASS