实现一个轻量级线程池

224 阅读2分钟

我们知道，线程池的实现原理非常简单，主要包括以下 3 个部分：

  1. 线程池的创建与销毁
  2. 线程池中 worker 的管理
  3. 任务的提交与调度

image.png

从上图可以看出 workpool 的大致实现原理：capacity 代表线程池的最大容量，我们使用一个带缓冲的 channel（active）作为计数器。当 active channel 可写时，我们就创建一个 worker；当 active channel 不可写时（即 worker 数量已达到 capacity），pool 停止创建 worker，直到某个运行中的 worker 退出并释放 active 中的一个位置，pool 才可以继续创建 worker。我们把自己的任务抽象成一个个 task，用户通过 Schedule 将 task 提交到 task channel，已经创建的 worker 从 task channel 中读取 task 并执行。workpool 的工作原理是不是很简单？Talk is cheap, show me the code. 接下来我们就来实现一个简易版的 workpool，来证明上述分析的可行性。

package workpool

import (
 "errors"
 "fmt"
 "sync"
)

// defaultCapacity is the pool size NewPool falls back to when it is given a
// non-positive capacity.
const defaultCapacity = 10

// maxCapacity is the largest capacity NewPool accepts; larger requests are
// clamped down to it.
const maxCapacity = 100

// Task is a unit of work submitted to a Pool through Schedule.
type Task func()

// Pool is a bounded worker pool: at most capacity workers are alive at once,
// and each worker executes Tasks received from the tasks channel.
type Pool struct {
	capacity int            // maximum number of live workers
	active   chan struct{}  // holds one token per live worker; its buffer size bounds the worker count
	tasks    chan Task      // tasks handed from Schedule to workers
	quit     chan struct{}  // closed by Free to signal shutdown
	wg       sync.WaitGroup // Free waits on it so shutdown blocks until the pool's goroutines exit
}

// NewPool creates a Pool that keeps at most capacity workers alive and starts
// its dispatcher goroutine. A non-positive capacity is replaced by
// defaultCapacity, and a capacity above maxCapacity is clamped to maxCapacity.
// Call Free to shut the pool down.
func NewPool(capacity int) *Pool {
	if capacity <= 0 {
		capacity = defaultCapacity
	}
	if capacity > maxCapacity {
		capacity = maxCapacity
	}

	fmt.Println("workpool start")

	p := &Pool{
		capacity: capacity,
		active:   make(chan struct{}, capacity),
		tasks:    make(chan Task),
		quit:     make(chan struct{}),
	}

	// Register the dispatcher in wg *before* starting it. While run is alive
	// the counter stays above zero, so the wg.Add calls made on behalf of new
	// workers can never race with Free's wg.Wait. It also makes Free wait for
	// run itself to return instead of leaving it unaccounted for.
	p.wg.Add(1)
	go p.run()

	return p
}

// run is the dispatcher loop. Each time a token fits into p.active (fewer
// than capacity workers are alive) it starts a new worker, and it returns
// once p.quit is closed. Because select chooses randomly among ready cases, a
// worker may still be started just after quit is closed; Free waits for it
// like any other.
func (p *Pool) run() {
	defer p.wg.Done()

	idx := 0
	for {
		select {
		// Free closed quit: stop creating workers.
		case <-p.quit:
			fmt.Println("run method exit")
			return
		// A slot is free in active: start another worker.
		case p.active <- struct{}{}:
			idx++
			// Call createWorker synchronously so its wg.Add(1) happens here,
			// while run still holds its own wg count, rather than in a new
			// goroutine that might only run after Free has started waiting.
			p.createWorker(idx)
		}
	}
}

// createWorker registers a new worker in p.wg and starts it in its own
// goroutine. idx is only used to label log lines. The worker executes tasks
// received from p.tasks until p.quit is closed or a task panics. In either
// case it takes one token back out of p.active, freeing the slot that run
// filled before creating this worker.
func (p *Pool) createWorker(idx int) {
	// Add before spawning so the goroutine is accounted for before it can run.
	p.wg.Add(1)

	go func() {
		// Deferred calls run in reverse order: recover first, then release
		// the active slot, then signal wg.
		defer p.wg.Done()
		// Release the slot on every exit path (quit, panic, or a panic whose
		// recovered value is nil), so a slot can never leak.
		defer func() { <-p.active }()
		defer func() {
			if err := recover(); err != nil {
				// %v rather than %s: a panic value can be of any type.
				fmt.Printf("worker[%03d]: recover panic[%v] and exit\n", idx, err)
			}
		}()

		fmt.Printf("worker[%03d]: start\n", idx)

		for {
			select {
			case <-p.quit:
				fmt.Printf("worker[%03d]: exit\n", idx)
				return
			case t := <-p.tasks:
				fmt.Printf("worker[%03d]: receive a task\n", idx)
				t()
			}
		}
	}()
}

var (
	// ErrWorkerPoolFreed is the error Schedule reports when it observes that
	// the pool's quit signal has fired, i.e. the pool has been freed.
	ErrWorkerPoolFreed = errors.New("workerpool freed")
)

// Schedule submits t to the pool. It blocks until a worker receives t from
// the tasks channel and then returns nil. If it observes that p.quit has been
// closed, it returns ErrWorkerPoolFreed instead. When both are ready at once,
// select picks one of them at random.
func (p *Pool) Schedule(t Task) error {
	var err error
	select {
	case p.tasks <- t:
		// A worker has taken the task; err stays nil.
	case <-p.quit:
		err = ErrWorkerPoolFreed
	}
	return err
}

// Free shuts the pool down. It closes p.quit to signal every goroutine that
// selects on it, then blocks until every goroutine registered in p.wg has
// returned.
//
// Calling Free again after a previous call has returned no longer panics: it
// only waits (returning immediately) and logs again. Concurrent *first*
// calls are still unsupported, because two callers can both see quit open
// and both try to close it.
func (p *Pool) Free() {
	select {
	case <-p.quit:
		// Already freed; closing a closed channel would panic.
	default:
		close(p.quit)
	}
	p.wg.Wait()
	fmt.Printf("workerpool freed\n")
}

测试代码

package workpool

import (
   "testing"
   "time"
)

// TestWorkpool schedules more tasks than the pool has workers, frees the
// pool, and checks that every successfully scheduled task ran to completion
// before Free returned.
func TestWorkpool(t *testing.T) {
	const numTasks = 10

	pool := NewPool(5)

	// Each finished task drops a token here. The buffer holds every token,
	// so a task never blocks on it.
	done := make(chan struct{}, numTasks)

	for i := 0; i < numTasks; i++ {
		err := pool.Schedule(func() {
			time.Sleep(time.Second * 3)
			done <- struct{}{}
		})
		if err != nil {
			// Report through t so a scheduling failure actually fails the test
			// (the println builtin only writes to stderr).
			t.Errorf("task %d: schedule: %v", i, err)
		}
	}

	pool.Free()

	// A worker only checks quit between tasks, so every task it has received
	// finishes before the worker exits and Free's Wait returns.
	if got := len(done); got != numTasks {
		t.Errorf("completed tasks = %d, want %d", got, numTasks)
	}
}

测试运行结果

=== RUN   TestWorkpool
workpool start
worker[005]: start
worker[005]: receive a task
worker[001]: start
worker[001]: receive a task
worker[004]: start
worker[002]: start
worker[002]: receive a task
worker[004]: receive a task
worker[003]: start
worker[003]: receive a task
worker[004]: receive a task
worker[005]: receive a task
worker[003]: receive a task
run method exit
worker[002]: receive a task
worker[001]: receive a task
worker[001]: exit
worker[002]: exit
worker[005]: exit
worker[003]: exit
worker[004]: exit
workerpool freed
--- PASS: TestWorkpool (6.00s)
PASS