Go 并发网页下载工具(已修复已知bug)

151 阅读 · 阅读约需 5 分钟

网页批量下载器小工具功能清单


核心功能

  1. 多URL批量下载

    • 支持通过命令行参数 `-u/--urls` 输入多个URL(可重复该参数或用逗号分隔;目前不支持从文件读取URL列表)。
  2. 并发下载

    • 通过协程(goroutine)实现并发下载任务。
    • 可配置最大并发数。
  3. 断点续传

    • 支持下载中断后恢复(通过 HTTP Range 请求从已下载的字节处继续)。
  4. 错误处理与重试

    • 自动重试失败的下载任务(可配置重试次数)。
  5. 日志与进度显示

    • 实时显示每个下载任务的进度条(目前不显示总进度)。
package main

import (
   "crypto/md5"
   "encoding/hex"
   "fmt"
   "io"
   "net/http"
   "net/url"
   "os"
   "path/filepath"
   "regexp"
   "strings"
   "sync"
   "sync/atomic"
   "time"

   "github.com/PuerkitoBio/goquery"
   "github.com/cheggaaa/pb/v3"
   "github.com/spf13/cobra"
)

// DownloadTask describes one URL to fetch and the local file it is saved to,
// together with the per-task state that downloadFile and downloadWithRetry
// update while working on it.
type DownloadTask struct {
   // URL is the remote address requested with GET.
   URL           string
   // TargetPath is the local file the response body is written (or appended) to.
   TargetPath    string
   // CurrentSize is the number of bytes already written to TargetPath;
   // downloadFile initialises it from the existing file's size to resume.
   CurrentSize   int64
   // TotalSize is the expected full size: 0 = not determined yet,
   // -1 = unknown (the response carried no positive Content-Length).
   TotalSize     int64
   // Completed is set by downloadFile once it considers the download finished.
   Completed     bool
   // Retries counts failed attempts so far.
   Retries       int
   // MaxRetries is the retry budget; downloadWithRetry keeps attempting while
   // Retries <= MaxRetries, i.e. up to MaxRetries+1 attempts in total.
   MaxRetries    int
   // Bar is the task's progress bar, created lazily by downloadFile.
   Bar           *pb.ProgressBar
   // DownloadedURL is shared by every task of one run; its keys are the URLs
   // that have already been scheduled, used to avoid downloading twice.
   DownloadedURL *sync.Map
}

// activeTaskCount counts tasks that have been queued but have not finished.
// main sets it to the number of initial tasks before launching the workers;
// downloadWithRetry increments it before queueing each discovered
// link/resource and decrements it when a task ends. The task whose completion
// drops it to zero closes the task channel, which lets the workers exit.
var activeTaskCount int32 // atomic counter (sync/atomic)
var taskMutex sync.Mutex  // mutex serialising counter updates with the "all done -> close channel" check

// main wires up the command-line interface and runs the downloader.
//
// Every URL passed with -u/--urls becomes an initial task saved to
// <output>/<host>/<file name> ("index.html" when the URL path names no file).
// A pool of `concurrency` workers consumes tasks from taskCh; downloadWithRetry
// queues follow-up tasks for discovered links/resources and closes taskCh when
// no task is pending any more, which ends the workers.
func main() {
	var (
		concurrency int
		maxDepth    int
		outputDir   string
		maxRetries  int
		urls        []string
		resources   bool
	)

	rootCmd := &cobra.Command{
		Use:   "downloader",
		Short: "A concurrent website downloader",
		PreRun: func(cmd *cobra.Command, args []string) {
			if len(urls) == 0 {
				fmt.Println("请提供至少一个URL")
				os.Exit(1)
			}
			// With no worker nobody ever receives from taskCh, so the run
			// could never finish.
			if concurrency < 1 {
				fmt.Println("并发数必须大于0")
				os.Exit(1)
			}
		},
		Run: func(cmd *cobra.Command, args []string) {
			// Create the output directory.
			if err := os.MkdirAll(outputDir, 0755); err != nil {
				fmt.Printf("创建输出目录失败: %v\n", err)
				return
			}

			tasks := make([]*DownloadTask, 0, len(urls))
			// Shared by every task of this run; keys are URLs already scheduled.
			var downloadedURLs sync.Map

			for _, u := range urls {
				parsedURL, err := url.Parse(u)
				if err != nil {
					fmt.Printf("无效的URL %s: %v\n", u, err)
					continue
				}
				// Only absolute http(s) URLs can be fetched; e.g. "example.com"
				// parses without error but has neither scheme nor host.
				if (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") || parsedURL.Host == "" {
					fmt.Printf("无效的URL %s: 仅支持http/https地址\n", u)
					continue
				}
				// Register start URLs so that a repeated -u value, or a page
				// linking back to a start URL, does not schedule a second download.
				if _, seen := downloadedURLs.LoadOrStore(u, 0); seen {
					continue
				}

				// One sub-directory per host.
				urlDir := filepath.Join(outputDir, parsedURL.Hostname())
				if err := os.MkdirAll(urlDir, 0755); err != nil {
					fmt.Printf("创建URL目录失败 %s: %v\n", urlDir, err)
					continue
				}

				fileName := filepath.Base(parsedURL.Path)
				if fileName == "" || fileName == "." || fileName == ".." || fileName == "/" || fileName == "//" {
					fileName = "index.html"
				}

				tasks = append(tasks, &DownloadTask{
					URL:           u,
					TargetPath:    filepath.Join(urlDir, fileName),
					MaxRetries:    maxRetries,
					DownloadedURL: &downloadedURLs,
				})
			}

			// Without any task nothing would ever be sent on, or close, taskCh
			// and the workers would block forever.
			if len(tasks) == 0 {
				fmt.Println("没有可下载的有效URL")
				return
			}

			var wg sync.WaitGroup
			// Extra room for link/resource tasks discovered during the crawl.
			taskCh := make(chan *DownloadTask, len(tasks)*10)

			// Must be set before any worker can finish (and decrement for) a task.
			activeTaskCount = int32(len(tasks))

			for i := 0; i < concurrency; i++ {
				wg.Add(1)
				go func() {
					defer wg.Done()
					for task := range taskCh {
						downloadWithRetry(task, maxDepth, resources, taskCh, &activeTaskCount, &taskMutex)
					}
				}()
			}

			for _, task := range tasks {
				taskCh <- task
			}

			// Workers return only once taskCh has been closed and drained.
			wg.Wait()

			fmt.Println("所有下载任务已完成")
		},
	}

	rootCmd.Flags().IntVarP(&concurrency, "concurrency", "c", 5, "并发下载数")
	rootCmd.Flags().IntVarP(&maxDepth, "depth", "d", 1, "递归下载深度")
	rootCmd.Flags().StringVarP(&outputDir, "output", "o", "downloads", "输出目录")
	rootCmd.Flags().IntVarP(&maxRetries, "retries", "r", 3, "下载失败重试次数")
	rootCmd.Flags().StringSliceVarP(&urls, "urls", "u", []string{}, "要下载的URL列表")
	rootCmd.Flags().BoolVarP(&resources, "resources", "s", true, "是否下载资源文件(图片、CSS、JS等)")

	if err := rootCmd.Execute(); err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
}

// crawlState is the value downloadWithRetry stores in DownloadTask.DownloadedURL
// for every URL it schedules. It carries what is needed to place and limit
// follow-up tasks, for which DownloadTask itself has no field.
type crawlState struct {
	depth   int    // link hops from the start URL (start URL = 0)
	rootDir string // directory URL paths are mapped under (the start URL's <output>/<host>)
	leaf    bool   // resources are leaves: never parsed for further links/resources
}

// downloadWithRetry downloads task, retrying while task.Retries <=
// task.MaxRetries (2 s apart). When the result is an HTML page it schedules
// follow-up tasks on taskCh:
//   - links on a page at depth d are queued with depth d+1 only while
//     d < maxDepth, so maxDepth bounds the number of link hops;
//   - if downloadResources is set, the page's resources are queued as leaves;
//   - maxDepth <= 0 disables parsing altogether.
//
// Every queued task is counted in *activeCount before it is sent; the task
// whose completion drops the count to zero closes taskCh so workers exit.
func downloadWithRetry(task *DownloadTask, maxDepth int, downloadResources bool, taskCh chan<- *DownloadTask, activeCount *int32, mutex *sync.Mutex) {
	defer func() {
		mutex.Lock()
		defer mutex.Unlock()
		// Each pending task was counted before it was queued, so zero means
		// nothing is queued, in flight or running any more.
		if atomic.AddInt32(activeCount, -1) == 0 {
			close(taskCh)
		}
	}()

	state := lookupCrawlState(task)

	for task.Retries <= task.MaxRetries {
		err := downloadFile(task)
		if err == nil {
			fmt.Printf("成功下载: %s\n", task.URL)
			if !state.leaf && maxDepth > 0 && state.depth <= maxDepth && isHTMLFile(task.TargetPath) {
				scheduleChildren(task, state, maxDepth, downloadResources, taskCh, activeCount, mutex)
			}
			return
		}

		task.Retries++
		fmt.Printf("下载失败 %s: %v (重试 %d/%d)\n", task.URL, err, task.Retries, task.MaxRetries)
		// Only wait when another attempt will actually follow.
		if task.Retries <= task.MaxRetries {
			time.Sleep(time.Second * 2)
		}
	}
	fmt.Printf("下载失败,已达到最大重试次数: %s\n", task.URL)
}

// lookupCrawlState returns the crawl state recorded for task.URL. URLs without
// a crawlState entry (start URLs) get depth 0 and use the directory of their
// own target file as root; that default is also recorded for deduplication.
func lookupCrawlState(task *DownloadTask) crawlState {
	fallback := crawlState{rootDir: filepath.Dir(task.TargetPath)}
	v, _ := task.DownloadedURL.LoadOrStore(task.URL, fallback)
	if st, ok := v.(crawlState); ok {
		return st
	}
	return fallback
}

// scheduleChildren parses the downloaded page and queues its links (within the
// depth limit) and, if requested, its resources.
func scheduleChildren(task *DownloadTask, state crawlState, maxDepth int, downloadResources bool, taskCh chan<- *DownloadTask, activeCount *int32, mutex *sync.Mutex) {
	links, resources, err := extractLinksAndResources(task.TargetPath, task.URL)
	if err != nil {
		fmt.Printf("提取链接失败 %s: %v\n", task.URL, err)
		return
	}

	if state.depth < maxDepth {
		child := crawlState{depth: state.depth + 1, rootDir: state.rootDir}
		for _, link := range links {
			enqueueTask(task, link, "index.html", child, taskCh, activeCount, mutex)
		}
	}

	if downloadResources {
		leaf := crawlState{depth: state.depth + 1, rootDir: state.rootDir, leaf: true}
		for _, res := range resources {
			enqueueTask(task, res, "resource", leaf, taskCh, activeCount, mutex)
		}
	}
}

// enqueueTask records rawURL in the shared DownloadedURL map and, unless it was
// already scheduled, queues a task saving it under state.rootDir at the URL's
// own path. defaultName is used when the URL path names no file.
func enqueueTask(parent *DownloadTask, rawURL, defaultName string, state crawlState, taskCh chan<- *DownloadTask, activeCount *int32, mutex *sync.Mutex) {
	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		fmt.Printf("无效的URL %s: %v\n", rawURL, err)
		return
	}
	if _, loaded := parent.DownloadedURL.LoadOrStore(rawURL, state); loaded {
		return
	}

	fileName := filepath.Base(parsedURL.Path)
	switch fileName {
	case "", ".", "..", "/", "//", string(filepath.Separator):
		fileName = defaultName
	}

	newTask := &DownloadTask{
		URL: rawURL,
		// Paths are mapped under the start URL's root, not the parent file's
		// directory, so nested pages do not accumulate repeated path segments.
		TargetPath:    filepath.Join(state.rootDir, getPathFromURL(parsedURL), fileName),
		MaxRetries:    parent.MaxRetries,
		DownloadedURL: parent.DownloadedURL,
	}

	mutex.Lock()
	atomic.AddInt32(activeCount, 1)
	mutex.Unlock()

	// Workers are the only receivers, so a blocking send from a worker
	// deadlocks once the buffer is full. Fall back to a goroutine; the counter
	// above keeps taskCh open until this task has been received and processed.
	select {
	case taskCh <- newTask:
	default:
		go func() { taskCh <- newTask }()
	}
}

// downloadFile makes one download attempt for task. If part of the file is
// already on disk it requests the remaining bytes with a Range header and
// appends them.
//
// It updates task.CurrentSize, task.TotalSize (-1 when the size is unknown),
// task.Bar and task.Completed. Special cases:
//   - 416 in reply to a Range request: the server has no bytes past the local
//     size, which for an interrupted download normally means it already
//     finished; this is treated as success.
//   - 200 in reply to a Range request: the server ignored Range and sent the
//     whole body, so the file is rewritten from scratch instead of having the
//     full body appended to the partial data.
//   - a body shorter than the announced size is returned as an error, so the
//     caller's retry resumes from where it stopped.
func downloadFile(task *DownloadTask) error {
	// Resume point: whatever an earlier attempt or run left on disk.
	fi, err := os.Stat(task.TargetPath)
	switch {
	case err == nil && fi.IsDir():
		return fmt.Errorf("目标路径是目录: %s", task.TargetPath)
	case err == nil:
		task.CurrentSize = fi.Size()
	case os.IsNotExist(err):
		task.CurrentSize = 0
	default:
		return err
	}

	req, err := http.NewRequest("GET", task.URL, nil)
	if err != nil {
		return err
	}
	req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
	if task.CurrentSize > 0 {
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-", task.CurrentSize))
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	resuming := task.CurrentSize > 0
	switch {
	case resuming && resp.StatusCode == http.StatusRequestedRangeNotSatisfiable:
		task.Completed = true
		if task.Bar != nil {
			task.Bar.Finish()
		}
		return nil
	case resp.StatusCode >= 400:
		return fmt.Errorf("服务器返回错误: %s", resp.Status)
	case resuming && resp.StatusCode != http.StatusPartialContent:
		// Range ignored: the body starts at byte 0, so start over.
		task.CurrentSize = 0
		task.TotalSize = 0
	}

	if task.TotalSize == 0 {
		if resp.ContentLength > 0 {
			task.TotalSize = resp.ContentLength + task.CurrentSize
		} else {
			task.TotalSize = -1 // unknown size
		}
	}

	if err := os.MkdirAll(filepath.Dir(task.TargetPath), 0755); err != nil {
		return err
	}

	var file *os.File
	if task.CurrentSize > 0 {
		file, err = os.OpenFile(task.TargetPath, os.O_APPEND|os.O_WRONLY, 0644)
	} else {
		file, err = os.Create(task.TargetPath)
	}
	if err != nil {
		return err
	}

	if task.Bar == nil {
		if task.TotalSize > 0 {
			task.Bar = pb.Full.Start64(task.TotalSize)
		} else {
			task.Bar = pb.New(-1)
			task.Bar.Start()
		}
		task.Bar.Set("prefix", filepath.Base(task.URL)+" ")
	} else if task.TotalSize > 0 {
		task.Bar.SetTotal(task.TotalSize)
	}
	// Resync every attempt: a restart or a partial earlier attempt moves the offset.
	task.Bar.SetCurrent(task.CurrentSize)

	written, copyErr := io.Copy(file, &ProgressReader{Reader: resp.Body, Bar: task.Bar})
	// Close explicitly: some write errors only surface on Close.
	closeErr := file.Close()
	task.CurrentSize += written
	if copyErr != nil {
		return copyErr
	}
	if closeErr != nil {
		return closeErr
	}

	if task.TotalSize > 0 && task.CurrentSize < task.TotalSize {
		return fmt.Errorf("下载不完整 (%d/%d 字节): %w", task.CurrentSize, task.TotalSize, io.ErrUnexpectedEOF)
	}

	task.Completed = true
	task.Bar.Finish()
	return nil
}

// ProgressReader wraps an io.Reader and advances Bar by the number of bytes
// each Read returns, so copying through it drives the progress display.
type ProgressReader struct {
	Reader io.Reader
	Bar    *pb.ProgressBar
}

// Read forwards to the wrapped reader and reports any bytes obtained to Bar.
// The wrapped reader's byte count and error are returned unchanged.
func (pr *ProgressReader) Read(p []byte) (int, error) {
	count, readErr := pr.Reader.Read(p)
	if count <= 0 {
		return count, readErr
	}
	pr.Bar.Add(count)
	return count, readErr
}

// extractLinksAndResources parses the HTML file at filePath, which was
// downloaded from baseURL, and returns:
//   - links: absolute URLs of <a href> targets on the same host as baseURL;
//   - resources: absolute URLs of images, stylesheets and icons (<link rel>),
//     scripts, audio/video sources and posters, plus url(...) references in
//     <style> blocks and inline style="" attributes.
//
// Relative references are resolved against the document's <base href> when
// present, otherwise against baseURL. Both lists are de-duplicated.
func extractLinksAndResources(filePath, baseURL string) ([]string, []string, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return nil, nil, err
	}
	defer file.Close()

	doc, err := goquery.NewDocumentFromReader(file)
	if err != nil {
		return nil, nil, err
	}

	pageURL, err := url.Parse(baseURL)
	if err != nil {
		return nil, nil, err
	}

	// A <base href> changes what relative references resolve against.
	resolveBase := pageURL
	if href, ok := doc.Find("base[href]").First().Attr("href"); ok {
		if ref, perr := url.Parse(strings.TrimSpace(href)); perr == nil {
			resolveBase = pageURL.ResolveReference(ref)
		}
	}

	var links, resources []string

	addResource := func(raw string) {
		if abs := resolveURL(resolveBase, raw); abs != "" {
			resources = append(resources, abs)
		}
	}
	addFromAttr := func(selector, attr string) {
		doc.Find(selector).Each(func(_ int, s *goquery.Selection) {
			if v, ok := s.Attr(attr); ok {
				addResource(v)
			}
		})
	}

	doc.Find("a[href]").Each(func(_ int, s *goquery.Selection) {
		href, _ := s.Attr("href")
		abs := resolveURL(resolveBase, href)
		// Same-domain is judged against the page itself, not a <base> that
		// may point at another host.
		if abs != "" && isSameDomain(pageURL, abs) {
			links = append(links, abs)
		}
	})

	addFromAttr("img", "src")
	addFromAttr("script", "src")
	addFromAttr("video, audio, video source, audio source", "src")
	addFromAttr("video", "poster")

	// rel is a space-separated, case-insensitive token list
	// (e.g. "alternate stylesheet", "shortcut icon").
	doc.Find("link[href]").Each(func(_ int, s *goquery.Selection) {
		rel, _ := s.Attr("rel")
		for _, token := range strings.Fields(rel) {
			if strings.EqualFold(token, "stylesheet") || strings.EqualFold(token, "icon") {
				href, _ := s.Attr("href")
				addResource(href)
				break
			}
		}
	})

	// CSS url(...) references in <style> blocks and inline style attributes.
	doc.Find("style").Each(func(_ int, s *goquery.Selection) {
		resources = append(resources, extractCSSURLs(s.Text(), resolveBase)...)
	})
	doc.Find("[style]").Each(func(_ int, s *goquery.Selection) {
		style, _ := s.Attr("style")
		resources = append(resources, extractCSSURLs(style, resolveBase)...)
	})

	return removeDuplicates(links), removeDuplicates(resources), nil
}

// cssURLPattern matches CSS url(...) references with optional quotes and
// optional whitespace inside the parentheses. It is compiled once at package
// level instead of on every call.
var cssURLPattern = regexp.MustCompile(`url\(\s*['"]?([^'")]+?)['"]?\s*\)`)

// extractCSSURLs returns the absolute form of every url(...) reference in
// cssText, resolved against baseURL. References that resolveURL rejects (it
// returns "") are skipped. The result is never nil.
func extractCSSURLs(cssText string, baseURL *url.URL) []string {
	matches := cssURLPattern.FindAllStringSubmatch(cssText, -1)
	urls := make([]string, 0, len(matches))
	for _, m := range matches {
		if abs := resolveURL(baseURL, strings.TrimSpace(m[1])); abs != "" {
			urls = append(urls, abs)
		}
	}
	return urls
}

// resolveURL resolves href (as found in an HTML attribute or CSS url())
// against base and returns the absolute URL string. It returns "" when href
// should not be downloaded: empty or fragment-only references, unparsable
// references, and anything whose resolved scheme is not http or https
// (javascript:, mailto:, tel:, data:, ...).
//
// The fragment is dropped because it never changes the fetched document, so
// "page#a" and "page#b" map to the same download.
func resolveURL(base *url.URL, href string) string {
	// Attribute values routinely carry surrounding whitespace/newlines.
	href = strings.TrimSpace(href)
	if href == "" || strings.HasPrefix(href, "#") {
		return ""
	}

	ref, err := url.Parse(href)
	if err != nil {
		return ""
	}

	abs := base.ResolveReference(ref)
	// url.Parse lower-cases the scheme, so "JavaScript:" is caught here too.
	if abs.Scheme != "http" && abs.Scheme != "https" {
		return ""
	}
	abs.Fragment = ""
	return abs.String()
}

// isSameDomain reports whether urlStr has the same host name as baseURL.
// Ports are ignored (Hostname strips them), and the comparison is
// case-insensitive because DNS host names are. Input that cannot be parsed
// is never considered same-domain.
func isSameDomain(baseURL *url.URL, urlStr string) bool {
	u, err := url.Parse(urlStr)
	if err != nil {
		return false
	}
	return strings.EqualFold(u.Hostname(), baseURL.Hostname())
}

// removeDuplicates returns urls with repeated entries removed, keeping the
// first occurrence of each and preserving the original order. Entries are
// keyed by the hex-encoded MD5 of the string. The result is never nil.
func removeDuplicates(urls []string) []string {
	kept := make([]string, 0, len(urls))
	seen := make(map[string]struct{})

	for i := range urls {
		sum := md5.Sum([]byte(urls[i]))
		key := hex.EncodeToString(sum[:])
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		kept = append(kept, urls[i])
	}

	return kept
}

// isHTMLFile reports whether filePath looks like an HTML page, judging only by
// its extension (case-insensitive): ".html", ".htm", or no extension at all.
func isHTMLFile(filePath string) bool {
	switch strings.ToLower(filepath.Ext(filePath)) {
	case ".html", ".htm", "":
		return true
	default:
		return false
	}
}

// getPathFromURL returns the directory part of parsedURL.Path as a relative,
// OS-specific path with no leading separator. It returns "" when the file
// lives at the root. A path ending in "/" is treated as a directory in full,
// so "/a/b/" gives "a/b" while "/a/b" gives "a".
//
// The URL path comes from untrusted HTML and may still contain ".." segments
// (e.g. percent-encoded dots are decoded into Path). It is therefore cleaned
// as a rooted path first, so the result can never climb above the directory
// it is later joined to.
func getPathFromURL(parsedURL *url.URL) string {
	p := parsedURL.Path
	if p == "" || p == "/" || p == "//" {
		return ""
	}

	sep := string(filepath.Separator)
	// Rooting the path before Clean makes ".." stop at the root.
	cleaned := filepath.Clean(sep + filepath.FromSlash(p))

	dir := cleaned
	if !strings.HasSuffix(p, "/") {
		dir = filepath.Dir(cleaned)
	}

	dir = strings.TrimLeft(dir, sep)
	if dir == "" || dir == "." {
		return ""
	}
	return dir
}

此版本下载工具已修复已知的bug。