动态控制协程个数

59 阅读2分钟

定义配置文件 conf/config.yaml（程序从 ./conf 目录读取该文件）

concurrency:
  max_workers: 3   # 消费协程数量；修改并保存后自动扩缩容（≤0 时按 1 处理）

定义订阅者 Subscriber：每个订阅者启动一个协程，从数据通道中消费消息

package main

import (
    "fmt"
    "sort"
    "sync"
)

// Subscriber is one consumer in the pool. Its goroutine is started by
// NewSubscriber; the sync.Once fields ensure the creation log line and the
// goroutine start each happen at most once per Subscriber.
type Subscriber struct {
    ID         int
    createOnce sync.Once
    readOnce   sync.Once
}

// NewSubscriber creates a Subscriber with the given ID, logs its creation and
// starts a goroutine that prints every message read from dataCh.
//
// The goroutine stops when close yields a value (either a send on it or a
// close of it) and logs its destruction. It also stops, silently, if dataCh is
// closed: a closed channel is always ready and yields "", so without the ok
// check the loop would spin and print empty messages forever.
func NewSubscriber(Id int, dataCh chan string, close chan bool) (subscriber *Subscriber) {
    subscriber = &Subscriber{ID: Id}
    subscriber.createOnce.Do(func() {
        fmt.Printf("【%d】已创建\n", subscriber.ID)
    })
    subscriber.readOnce.Do(func() {
        go func(sub *Subscriber) {
            for {
                select {
                case <-close:
                    fmt.Printf("【%d】已销毁\n", sub.ID)
                    return
                case data, ok := <-dataCh:
                    if !ok {
                        return
                    }
                    fmt.Printf("【%d】消费中:%s\n", sub.ID, data)
                }
            }
        }(subscriber)
    })
    return subscriber
}

定义动态消费者管理结构体 ConsumerManager

// ConsumerManager owns a dynamically sized pool of Subscribers that all read
// from one shared data channel. mu is the lock AddConsumer and RemoveConsumer
// hold while they touch subs, closeCh and subNum.
type ConsumerManager struct {
    // subs maps a subscriber ID to its running Subscriber.
    subs map[int]*Subscriber
    // closeCh maps a subscriber ID to the channel used to stop it.
    closeCh map[int]chan bool
    // dataCh is handed to every subscriber as its input channel.
    dataCh chan string
    // subNum counts the registered subscribers.
    subNum int
    // mu serializes access to the maps and the counter above.
    mu sync.Mutex
}

实现构造函数 NewConsumerManager

// NewConsumerManager returns a manager with no subscribers yet; every
// subscriber it later starts will consume from dataCh.
func NewConsumerManager(dataCh chan string) (consumerManager *ConsumerManager) {
    consumerManager = new(ConsumerManager)
    consumerManager.subs = make(map[int]*Subscriber)
    consumerManager.closeCh = make(map[int]chan bool)
    consumerManager.dataCh = dataCh
    return consumerManager
}

实现添加、移除消费者以及自动调度的方法

/**
 * @name: RemoveConsumer
 * @desc: Stops and unregisters the subscriber with the given ID; does
 *        nothing if no such subscriber is registered. Holds cm.mu for the
 *        whole operation.
 *        The stop signal is delivered by closing the subscriber's close
 *        channel instead of sending on it. close never blocks, so cm.mu
 *        cannot be held forever if the subscriber goroutine has already
 *        exited. The subscriber then exits asynchronously and may still
 *        finish a message it is already handling.
 * @param {int} subId ID of the subscriber to remove
 * @return {*}
 */
func (cm *ConsumerManager) RemoveConsumer(subId int) {
    cm.mu.Lock()
    defer cm.mu.Unlock()
    if closeCh, ok := cm.closeCh[subId]; ok {
        // A receive on a closed channel succeeds immediately, which wakes the
        // subscriber's select on its close channel.
        close(closeCh)
        delete(cm.closeCh, subId)
    }
    if _, ok := cm.subs[subId]; ok {
        delete(cm.subs, subId)
        cm.subNum--
    }
}

/**
 * @name: AddConsumer
 * @desc: Starts a subscriber with the given ID and registers it together with
 *        its close channel. Does nothing if the ID is already registered.
 *        Holds cm.mu for the whole operation.
 * @param {int} subId ID of the subscriber to add
 * @return {*}
 */
func (cm *ConsumerManager) AddConsumer(subId int) {
    cm.mu.Lock()
    defer cm.mu.Unlock()
    if _, exists := cm.subs[subId]; exists {
        return
    }
    stop := make(chan bool)
    cm.subs[subId] = NewSubscriber(subId, cm.dataCh, stop)
    cm.closeCh[subId] = stop
    cm.subNum++
}

/**
 * @name: AutoScheduling
 * @desc: Resizes the pool so that exactly the subscribers with IDs 1..num are
 *        running (num <= 0 is treated as 1).
 *        Steps:
 *          1. Remove subscribers whose IDs fall outside 1..num, highest ID
 *             first.
 *          2. Add any missing IDs in 1..num.
 *        The registered IDs are read under cm.mu rather than through an
 *        unsynchronized read of subNum. No contiguity of existing IDs is
 *        assumed, so consumers added or removed directly via AddConsumer or
 *        RemoveConsumer do not leave the pool at the wrong size.
 *        NOTE(review): two concurrent AutoScheduling calls can still
 *        interleave their add/remove steps; callers should serialize them.
 * @param {int} num target number of subscribers
 * @return {*}
 */
func (cm *ConsumerManager) AutoScheduling(num int) {
    if num <= 0 {
        num = 1
    }

    cm.mu.Lock()
    var stale []int
    for subId := range cm.subs {
        if subId < 1 || subId > num {
            stale = append(stale, subId)
        }
    }
    cm.mu.Unlock()

    // Map iteration order is random; sort so removals happen in a fixed
    // order, highest ID first.
    sort.Sort(sort.Reverse(sort.IntSlice(stale)))
    for _, subId := range stale {
        cm.RemoveConsumer(subId)
    }
    // AddConsumer returns early for IDs that are already registered.
    for subId := 1; subId <= num; subId++ {
        cm.AddConsumer(subId)
    }
}

调用示例：读取配置、启动生产者，并在配置文件变化时动态调整协程数

package main

import (
    "fmt"
    "log"
    "time"

    "github.com/fsnotify/fsnotify"
    "github.com/spf13/viper"
)

// limit is the capacity of the jobs channel.
const limit = 5

// Package-level state shared by readConf and main.
var (
    // config holds the most recently decoded config.yaml contents.
    config Config
    // jobs is a buffered channel of capacity limit.
    jobs = make(chan int, limit)
    // closeChMap is not referenced by live code in this file.
    closeChMap = map[int]chan bool{}
    // dataCh carries producer messages to the subscribers; main, as written,
    // reassigns it to an unbuffered channel before use.
    dataCh = make(chan string, 100)
)

// Config mirrors the layout of config.yaml; the mapstructure tags name the
// keys viper decodes from.
type Config struct {
    Concurrency struct {
        MaxWorkers int `mapstructure:"max_workers"` // number of subscribers to run
    } `mapstructure:"concurrency"`
}

// worker prints every message received on dataCh, prefixed with workName,
// until closeCh yields a value (a send on it or a close of it) or dataCh is
// closed. Nothing in this file starts it.
func worker(workName string, dataCh chan string, closeCh chan bool) {
    for {
        select {
        case <-closeCh:
            fmt.Printf("【%s】worker close\n", workName)
            return
        case data, ok := <-dataCh:
            if !ok {
                // A closed dataCh is always ready and yields ""; returning
                // here avoids a busy loop of empty messages.
                return
            }
            fmt.Printf("【%s】%s\n", workName, data)
        }
    }
}

// readConf loads config.yaml from ./conf into the package-level config, then
// watches the file and re-applies the worker count through cm whenever it
// changes.
//
// Error handling:
//   - A failed initial read panics.
//   - A failed initial decode exits via log.Fatalf.
//   - A failed decode on reload is only logged; the previous config stays in
//     effect, so a bad edit does not kill the process.
func readConf() {
    viper.SetConfigName("config") // name of config file (without extension)
    viper.SetConfigType("yaml")   // REQUIRED if the config file does not have the extension in the name
    viper.AddConfigPath("./conf") // path to look for the config file in
    if err := viper.ReadInConfig(); err != nil {
        panic(fmt.Errorf("fatal error config file: %w", err))
    }
    if err := viper.Unmarshal(&config); err != nil {
        log.Fatalf("Unable to decode into struct, %v", err)
    }
    fmt.Printf("Config: %+v\n", config)

    // Register the callback before starting the watcher so no change event
    // can be handled without it.
    viper.OnConfigChange(func(e fsnotify.Event) {
        fmt.Println("Config file changed:", e.Name)
        // Decode into a copy (so keys absent from the file keep their current
        // values, as before) and publish it only on success, so a decode error
        // never leaves config half-updated.
        next := config
        if err := viper.Unmarshal(&next); err != nil {
            log.Printf("Unable to decode into struct, %v", err)
            return
        }
        config = next
        fmt.Printf("Config: %+v\n", config)
        cm.AutoScheduling(config.Concurrency.MaxWorkers)
    })
    viper.WatchConfig()
}

// cm is the process-wide consumer manager. main assigns it before readConf
// registers the config-change callback that uses it.
var cm *ConsumerManager

// main sizes the consumer pool from concurrency.max_workers and starts a
// producer that sends "Hello" into dataCh, sleeping one second after each
// send. It blocks forever.
func main() {
    // Replace the buffered channel from the var block with an unbuffered one,
    // so each send waits until a subscriber receives it.
    dataCh = make(chan string)
    cm = NewConsumerManager(dataCh)
    readConf()
    // Start with the configured worker count instead of a hard-coded value, so
    // the pool matches config.yaml from the outset rather than only after the
    // first file change.
    // NOTE(review): a config-change callback firing right after readConf
    // returns could race with this read of config.
    cm.AutoScheduling(config.Concurrency.MaxWorkers)

    go func() {
        for {
            dataCh <- "Hello"
            time.Sleep(time.Second)
        }
    }()
    select {}
}