[golang]使用协程池处理http下载图片

33 阅读1分钟

实现思路

  • 把图片地址导入到csv文件中
  • 开启大小为8的协程池
  • http请求注意超时控制和重试

package main

import (
	"encoding/csv"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"
)

// Task is a unit of work submitted to a Pool.
type Task struct {
	ID  string // identifier printed by the worker after Job returns
	Job func() // function run by a worker goroutine
}

// Pool is a fixed-size set of worker goroutines that receive tasks from a
// shared unbuffered channel.
type Pool struct {
	taskQueue chan Task
	wg        sync.WaitGroup
}

// NewPool creates a Pool and starts numWorkers worker goroutines.
//
// A numWorkers value below 1 is treated as 1: with zero workers AddTask would
// block forever on the unbuffered queue, and a negative count would make
// WaitGroup.Add panic.
func NewPool(numWorkers int) *Pool {
	if numWorkers < 1 {
		numWorkers = 1
	}

	p := &Pool{
		taskQueue: make(chan Task),
	}

	p.wg.Add(numWorkers)
	for i := 0; i < numWorkers; i++ {
		go p.worker()
	}

	return p
}

// AddTask hands task to a worker. The queue is unbuffered, so the call blocks
// until a worker is free to receive it. It must not be called after Wait,
// because Wait closes the queue.
func (p *Pool) AddTask(task Task) {
	p.taskQueue <- task
}

// worker runs tasks from the queue until the queue is closed, then signals
// the WaitGroup that it has exited.
func (p *Pool) worker() {
	defer p.wg.Done()
	for task := range p.taskQueue {
		task.Job()
		fmt.Printf("task worker %v finished\n", task.ID)
	}
}

// Wait closes the task queue and blocks until every worker has finished the
// tasks it received and exited.
func (p *Pool) Wait() {
	close(p.taskQueue)
	p.wg.Wait()
}

var csv2, path string

func main() {
    //参数解析  
    flag.StringVar( & csv2, "csv", "", "csv文件的绝对路径")
    flag.StringVar( & path, "path", ".", "保存图片的文件夹")
    flag.Parse()

    // 打开CSV文件  
    file, err: = os.Open(csv2)
    if err != nil {
        fmt.Println("无法打开CSV文件:", err)
        return
    }
    defer file.Close()

    // 创建一个CSV阅读器  
    reader: = csv.NewReader(file)

    // 读取所有行  
    lines, err: = reader.ReadAll()
    if err != nil {
        fmt.Println("读取CSV文件失败:", err)
        return
    }

    // 创建一个协程池,设置工作协程数为8  
    pool: = NewPool(8)

    maxRetries: = 3
    timeout: = 20 * time.Second

    //遍历csv文件  
    for _, line: = range lines {
        taskID: = line[0]
        task: = Task {
            ID: taskID,
            Job: func() {
                arr: = strings.Split(taskID, "/")
                fileName: = arr[len(arr) - 1]
                filePath: = path + "/" + fileName
                err: = downloadImageWithRetry(taskID, filePath, maxRetries, timeout)
                if err != nil {
                    log.Println(`get file err: `, err, taskID)
                }
            },
        }
        pool.AddTask(task)
    }

    // 等待所有任务完成  
    pool.Wait()
}

func downloadImageWithRetry(url, outputPath string, maxRetries int, timeout time.Duration) error {
    var err error
    for retries: = 0;
    retries <= maxRetries;
    retries++{
        err = downloadImage(url, outputPath, timeout)
        if err == nil {
            return nil
        }
        fmt.Printf("Download attempt %d failed: %s\n", retries + 1, err)
    }
    return err
}

// downloadImage performs an HTTP GET on url, bounded by timeout, and writes
// the response body to outputPath.
//
// Any response other than 200 OK is returned as an error before the output
// file is created, so an error page is never saved as an image and the
// caller's retry logic can react. If writing the file fails, the partial file
// is removed.
func downloadImage(url, outputPath string, timeout time.Duration) error {
	client := http.Client{
		Timeout: timeout,
	}
	response, err := client.Get(url)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected HTTP status %q", response.Status)
	}

	file, err := os.Create(outputPath)
	if err != nil {
		return err
	}

	if _, err := io.Copy(file, response.Body); err != nil {
		// Best-effort cleanup; the copy error is the one worth reporting.
		_ = file.Close()
		_ = os.Remove(outputPath)
		return err
	}

	// Close explicitly rather than via defer: a failed close can mean the
	// data never reached disk.
	if err := file.Close(); err != nil {
		_ = os.Remove(outputPath)
		return err
	}

	return nil
}