使用 Golang 模拟一个简单的生产者消费者模型

41 阅读2分钟

self

package main

import (
	"fmt"
	"sync"
	"time"
)

const (
	WORKER_NUMBER = 5
	TASK_NUMBER   = 1000
)

type Task struct {
	taskId int
	work   int
}

var (
	sum int
	wg  sync.WaitGroup
)

func main() {
	// sync channel
	task_ch := make(chan Task)
	res_ch := make(chan int, TASK_NUMBER)

	start := time.Now()

	// consumer
	for i := 1; i <= WORKER_NUMBER; i++ {
		wg.Add(1)
		go func(workerId int) {
			for t := range task_ch {
				fmt.Printf("worker %d is process task %d....\n", workerId, t.taskId)
				res_ch <- t.work * t.work
			}
			wg.Done()
		}(i)
	}

	// producer
	var t Task
	for task := 0; task < TASK_NUMBER; task++ {
		t = Task{
			taskId: task,
			work:   task,
		}
		fmt.Printf("task %d is produced...\n", task)
		task_ch <- t
	}
	close(task_ch)
	wg.Wait()
	close(res_ch)

	for res := range res_ch {
		sum += res
	}
	fmt.Println("producer and consumer demo's result is ", sum)
	fmt.Printf("time use %v", time.Since(start))
}

存在的问题:如果使用 WaitGroup 管理 worker goroutine,我们使用 for 从 result channel 接收 res 以此达到阻塞的目的:但不知道到底哪一个 goroutine 是最后一个 worker goroutine,因此无法正常关闭 res channel,导致最后的 for 循环阻塞,最终导致死锁。

chatgpt 同步通道版本

针对上述问题,将 producer、consumer 全部放到 goroutine 中执行;

除此之外,我们需要调用 wg.Wait() 阻塞等待 worker goroutine 的结束顺势 close 掉 res channel;

但是一旦我们将 wg.Wait() 放到主 goroutine 中,就无法从 res channel 中读取数据,这将导致死锁。

所以我们可以将 wg.Wait() 也放到 goroutine 处理;

在主 goroutine 中,使用 for 从 res channel 中读取 worker 处理后的数据从而到达阻塞的目的。

当所有的 worker 处理完毕后,wg.Wait() 从阻塞中结束,res channel 成功被关闭,主 goroutine 也成功从 for channel 阻塞中退出,继续往下执行,主 goroutine 退出,程序正常结束。

package main

import (
	"fmt"
	"sync"
	"time"
)

const (
	WORKER_NUMBER = 5
	TASK_NUMBER   = 1000
)

type Task struct {
	taskId int
	work   int
}

var (
	sum int
	wg  sync.WaitGroup
)

func worker(task_ch <-chan Task, res_ch chan<- int, workerId int) {
	for t := range task_ch {
		fmt.Printf("worker %d is processing task %d...\n", workerId, t.taskId)
		res_ch <- t.work * t.work
	}
	wg.Done()
}

func main() {
	start := time.Now()
	// Sync channels
	task_ch := make(chan Task)
	res_ch := make(chan int)

	// Start workers
	wg.Add(WORKER_NUMBER)
	for i := 1; i <= WORKER_NUMBER; i++ {
		go worker(task_ch, res_ch, i)
	}

	// Producer
	go func() {
		for task := 0; task < TASK_NUMBER; task++ {
			fmt.Printf("task %d is produced...\n", task)
			task_ch <- Task{taskId: task, work: task}
		}
		close(task_ch)
	}()

	// Wait for workers to finish
	go func() {
		wg.Wait()
		close(res_ch)
	}()

	// Consume results
	for res := range res_ch {
		sum += res
	}

	fmt.Println("producer and consumer demo's result is ", sum)
	fmt.Println("time used:", time.Since(start))
}

chatgpt:使用异步通道实现

package main

import (
    "fmt"
    "sync"
)

const (
    WORKER_NUMBER = 5
    TASK_NUMBER   = 1000
)

type Task struct {
    taskId int
    work   int
}

var (
    sum int
    wg  sync.WaitGroup
)

func worker(task_ch <-chan Task, res_ch chan<- int, workerId int) {
    for t := range task_ch {
        fmt.Printf("worker %d is processing task %d...\n", workerId, t.taskId)
        res_ch <- t.work * t.work
    }
    wg.Done()
}

func main() {
    // Async channels
    task_ch := make(chan Task, 10)
    res_ch := make(chan int, 10)

    // Start workers
    wg.Add(WORKER_NUMBER)
    for i := 1; i <= WORKER_NUMBER; i++ {
        go worker(task_ch, res_ch, i)
    }

    // Producer
    go func() {
        for task := 0; task < TASK_NUMBER; task++ {
            fmt.Printf("task %d is produced...\n", task)
            task_ch <- Task{taskId: task, work: task}
        }
        close(task_ch)
    }()

    // Wait for workers to finish
    go func() {
        wg.Wait()
        close(res_ch)
    }()

    // Consume results
    for res := range res_ch {
        sum += res
    }

    fmt.Println("producer and consumer demo's result is ", sum)
}