手把手教你写一个golang协程池

800 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

go语言的协程是在是太方便了,一个go就能搞定一切,在简单场景下,是真香! 但是go方法是不是有点java Runnable的味道,如果要知道goroutine的执行状态,还得自己写chan来同步信息。而且协程开多了也会造成内存泄漏和性能问题。

本文将将教你写一个无限队列的golang协程池!

定义协程池

type GorunPool struct {
    Size int
    Ticket int
    Stoped bool
    ResChan chan *GorunTask
    lock sync.RWMutex
    Tasks []*GorunTask
    Once sync.Once
}

结构体说明

  1. Size是池子的大小
  2. Ticket是已经分配出去的数量
  3. Stoped是停止状态,用于整体退出
  4. ResChan是执行完成回传的chan
  5. lock是并发锁
  6. Once是初始化方法

定义协程执行对象

type GorunTask struct {
    Id int64
    Name string
    Status string
    Run func() error
    Callback func()
    Err error
}
  1. Id 任务ID
  2. Name自定义的任务名
  3. Status 任务状态(inited,waitting,running,execed)
  4. Run 是执行函数
  5. Callback 是回调函数,goroutine执行完成后回调
  6. Err 执行异常信息

初始化函数

func NewGorunPool(size int) *GorunPool {
    pool := GorunPool{
        Size: size,
        Ticket: 0,
        Stoped: false,
        ResChan: make(chan *GorunTask),
        Tasks: []*GorunTask{},
    }
    go pool.Once.Do(func() {
        for {
            if pool.Stoped {
                break
            }
            tk := <-pool.ResChan
            if tk.Callback != nil {
                go tk.Callback()
            }
            pool.lock.Lock()
            pool.Ticket--
            pool.lock.Unlock()
        }
    })
    go pool.execqueue()
    return &pool
}

func NewGorunTaskWithName(name string, run func() error, callback func()) *GorunTask {
    task := NewGorunTask(run, callback)
    task.Name = name
    return task
}

func NewGorunTask(run func() error, callback func()) *GorunTask {
    t0 := time.Now()
    task := GorunTask{
        Id: t0.UnixNano(),
        Run: run,
        Callback: callback,
        Status: "inited",
    }
    return &task
}

协程池执行方案

func (pool *GorunPool) Execute(task *GorunTask) {
    //1.判断当前pool的状态
    pool.lock.Lock()
    if pool.Ticket >= pool.Size { //队列等待
        task.Status = "waitting"
        pool.Tasks = append(pool.Tasks, task)
    } else {
        pool.call(task)
    }
    pool.lock.Unlock()
}

func (pool *GorunPool) call(task *GorunTask) {
    pool.Ticket++
    go func() {
        task.Status = "running"
        task.Err = task.Run()
        task.Status = "exected!"
        pool.ResChan <- task
    }()
}

func (pool *GorunPool) execqueue() {
    for {
        if pool.Stoped {
            break
        }
        if len(pool.Tasks) > 0 && pool.Size-pool.Ticket > 0 {
            pool.lock.Lock()
            task := pool.Tasks[0]
            pool.Tasks = pool.Tasks[1:len(pool.Tasks)]
            pool.lock.Unlock()
            pool.call(task)
        } else {
            time.Sleep(100 * time.Microsecond)
        }
    }
}
  1. func (pool *GorunPool) Execute(task *GorunTask) 执行任务的核心入口
  2. func (pool *GorunPool) call(task *GorunTask) 内部方法执行任务
  3. func (pool *GorunPool) execqueue() 执行等待队列

完整代码

package utils

import (
    "sync"
    "time"
)
/**
* 一个固定大小的goroutine 协程池
* 拦截和排队过量的并发协程数量
*/
type GorunTask struct {
    Id int64
    Name string
    Status string
    Run func() error
    Callback func()
    Err error
}

func NewGorunTaskWithName(name string, run func() error, callback func()) *GorunTask {
    task := NewGorunTask(run, callback)
    task.Name = name
    return task
}

func NewGorunTask(run func() error, callback func()) *GorunTask {
    t0 := time.Now()
    task := GorunTask{
        Id: t0.UnixNano(),
        Run: run,
        Callback: callback,
        Status: "inited",
    }
    return &task
}

type GorunPool struct {
    Size int
    Ticket int
    Stoped bool
    ResChan chan *GorunTask
    lock sync.RWMutex
    Tasks []*GorunTask
    Once sync.Once
}

func NewGorunPool(size int) *GorunPool {
    pool := GorunPool{
        Size: size,
        Ticket: 0,
        Stoped: false,
        ResChan: make(chan *GorunTask),
        Tasks: []*GorunTask{},
    }
    go pool.Once.Do(func() {
        for {
            if pool.Stoped {
                break
            }
            tk := <-pool.ResChan
            if tk.Callback != nil {
                go tk.Callback()
            }
            pool.lock.Lock()
            pool.Ticket--
            pool.lock.Unlock()
        }
    })
    go pool.execqueue()
    return &pool
}


func (pool *GorunPool) call(task *GorunTask) {
    pool.Ticket++
    go func() {
        task.Status = "running"
        task.Err = task.Run()
        task.Status = "exected!"
        pool.ResChan <- task
    }()
}

func (pool *GorunPool) execqueue() {
    for {
        if pool.Stoped {
            break
        }
        if len(pool.Tasks) > 0 && pool.Size-pool.Ticket > 0 {
            pool.lock.Lock()
            task := pool.Tasks[0]
            pool.Tasks = pool.Tasks[1:len(pool.Tasks)]
            pool.lock.Unlock()
            pool.call(task)
        } else {
            time.Sleep(100 * time.Microsecond)
        }
    }
}

func (pool *GorunPool) Execute(task *GorunTask) {
    //1.判断当前pool的状态
    pool.lock.Lock()
    if pool.Ticket >= pool.Size { //队列等待
        task.Status = "waitting"
        pool.Tasks = append(pool.Tasks, task)
    } else {
        pool.call(task)
    }
    pool.lock.Unlock()
}

测试代码

func TestPool(t *testing.T) {
    pool := NewGorunPool(5)
    for i := 0; i < 10000; i++ {
        c := i
        job := func() error {
            time.Sleep(1 * time.Second)
            t.Error("do thread!", c)
            return nil
        }
        task := NewGorunTask(job, nil)
        pool.Execute(task)
    }
    time.Sleep(3 * time.Second)
}

测试方案是初始化一个并发5的协程池,每个任务内部休眠1秒,主线程等待3秒 在这3秒后应该打印出执行的任务是15个,测试结果如下:

Jietu20220224-144206.jpg