go实现并行网站递归下载工具

123 阅读7分钟

网页批量下载器小工具功能清单


核心功能

  1. 多URL批量下载

    • 支持输入多个URL列表(文件或命令行参数)。
  2. 并发下载

    • 通过协程(goroutine)实现并发下载任务。
    • 可配置最大并发数(避免服务器封禁)。
  3. 断点续传

    • 支持下载中断后恢复,记录已下载进度。
  4. 错误处理与重试

    • 自动重试失败的下载任务(可配置重试次数)。
  5. 日志与进度显示

    • 实时显示下载进度(单个任务和总进度)。
package main

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"path"
	"path/filepath"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/schollz/progressbar/v3"
	"github.com/spf13/cobra"
	"golang.org/x/net/html"
)

type Config struct {
	Url         []string
	Dir         string
	Concurrency int
	Timeout     int
	Retry       int
	Depth       int  // 递归深度
	DownloadRes bool // 是否下载资源
}

var downloadedURLs = sync.Map{}

func main() {
	cmd := newCmdWebpagedown()
	if err := cmd.Execute(); err != nil {
		fmt.Println(err)
	}
}

func newCmdWebpagedown() *cobra.Command {
	var cfg Config

	//获取当前工作目录
	currentDir, err := os.Getwd()
	if err != nil {
		currentDir = "."
	}

	cmd := &cobra.Command{
		Use:   "webpagedown",
		Short: "Download web pages",
		PreRunE: func(cmd *cobra.Command, args []string) error {
			if len(cfg.Url) == 0 {
				return fmt.Errorf("至少需要一个URL")
			}
			if cfg.Dir == "" {
				cfg.Dir = currentDir
			}
			return nil
		},
		RunE: func(cmd *cobra.Command, args []string) error {
			fmt.Println("URL:", cfg.Url)
			fmt.Println("Dir:", cfg.Dir)
			fmt.Println("Concurrency:", cfg.Concurrency)
			fmt.Println("Timeout:", cfg.Timeout)
			fmt.Println("Retry:", cfg.Retry)
			downloadWebPages(cfg.Url, cfg.Dir, cfg.Concurrency, cfg.Timeout, cfg.Retry, cfg.Depth, cfg.DownloadRes)
			return nil
		},
	}

	cmd.Flags().StringSliceVarP(&cfg.Url, "url", "u", []string{}, `被下载的网页URL(可以多次使用或用","间隔多个URL)`)
	cmd.Flags().StringVarP(&cfg.Dir, "dir", "d", "", "下载的网页保存目录,默认为当前目录")
	cmd.Flags().IntVarP(&cfg.Concurrency, "concurrency", "c", 1, "并发下载数")
	cmd.Flags().IntVarP(&cfg.Timeout, "timeout", "t", 30, "超时时间")
	cmd.Flags().IntVarP(&cfg.Retry, "retry", "r", 3, "重试次数")
	cmd.Flags().IntVarP(&cfg.Depth, "depth", "p", 1, "递归下载深度")
	cmd.Flags().BoolVarP(&cfg.DownloadRes, "resources", "s", false, "是否下载页面资源(图片等)")

	return cmd
}

// 实现批量下载网页的功能
func downloadWebPages(urls []string, dir string, concurrency int, timeout int, retry int, depth int, downloadRes bool) error {
	// 创建一个带取消的上下文
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// 添加信号处理
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(sigChan)

	// 监听中断信号
	go func() {
		<-sigChan
		fmt.Println("\n接收到中断信号,正在保存下载进度...")
		cancel() //立即触发取消操作
	}()

	//判断dir是否存在,不存在则创建
	if _, err := os.Stat(dir); os.IsNotExist(err) {
		os.MkdirAll(dir, os.ModePerm)
	}

	//创建一个http client
	client := &http.Client{
		Timeout: time.Duration(timeout) * time.Second,
	}

	//创建工作池
	var wg sync.WaitGroup
	jobs := make(chan string, concurrency)
	results := make(chan error, concurrency)

	//启动goroutine,从jobs中获取url,下载网页
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for url := range jobs {
				select {
				case <-ctx.Done():
					return
				default:
					err := downloadWebPageWithRetry(ctx, client, url, dir, retry, depth, downloadRes, concurrency)
					results <- err
				}
			}
		}()
	}

	//启动goroutine,从results中获取下载结果
	done := make(chan struct{})
	go func() {
		for result := range results {
			if result != nil && result != context.Canceled {
				fmt.Println(result)
			}
		}
		close(done)
	}()

	//将url放入jobs中
urlLoop:
	for _, url := range urls {
		select {
		case <-ctx.Done():
			break urlLoop
		case jobs <- url:
		}
	}

	//关闭jobs
	close(jobs)

	//等待所有goroutine结束
	wg.Wait()

	//关闭results
	close(results)

	//等待打印结果的goroutine结束
	<-done

	return nil
}

// 实现单个网页下载重试的功能
func downloadWebPageWithRetry(ctx context.Context, client *http.Client, url string, dir string, retry int, depth int, downloadRes bool, concurrency int) error {
	var lastErr error
	for i := 0; i < retry; i++ {
		select {
		case <-ctx.Done():
			return context.Canceled
		default:
			err := downloadWebPage(ctx, client, url, dir, depth, downloadRes, concurrency)
			if err == nil || err == context.Canceled {
				return err
			}
			lastErr = err
		}

	}
	return fmt.Errorf("下载网页失败(重试%d次):%s, 错误:%v", retry, url, lastErr)
}

// 实现单个网页下载的功能
func downloadWebPage(ctx context.Context, client *http.Client, urll string, dir string, depth int, downloadRes bool, concurrency int) error {
	// 检查并添加协议前缀
	if !strings.HasPrefix(strings.ToLower(urll), "http://") &&
		!strings.HasPrefix(strings.ToLower(urll), "https://") {
		urll = "http://" + urll
	}

	// 检查是否已下载
	if _, exists := downloadedURLs.Load(urll); exists {
		return nil
	}
	downloadedURLs.Store(urll, true)

	// 创建请求
	req, err := http.NewRequest("GET", urll, nil)
	if err != nil {
		return fmt.Errorf("创建请求失败:%v", err)
	}

	// 设置请求头
	req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36")
	req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8")
	req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8")
	req.Header.Set("Connection", "keep-alive")

	//解析URL
	parsedUrl, err := url.Parse(urll)
	if err != nil {
		return fmt.Errorf("解析URL失败:%s", urll)
	}

	// 根据URL路径创建目录结构
	urlPath := parsedUrl.Path
	if urlPath == "" || urlPath == "/" {
		urlPath = "/index.html"
	}

	// 构建目标目录路径:基础目录 + 域名 + URL路径的目录部分
	hostname := strings.TrimPrefix(parsedUrl.Hostname(), "www.")
	// 安全处理目录路径
	urlPath = sanitizeFileName(urlPath)
	targetDir := filepath.Join(dir, sanitizeFileName(hostname), filepath.Dir(urlPath))
	// 创建目录
	if err := os.MkdirAll(targetDir, os.ModePerm); err != nil {
		return fmt.Errorf("创建目录失败:%s", targetDir)
	}

	// 获取文件名
	filename := filepath.Base(parsedUrl.Path)
	if filename == "" || filename == "." || filename == "/" || filename == "\\" {
		filename = "index.html"
	}

	// 安全处理文件名
	filename = sanitizeFileName(filename)

	// 构建唯一文件名:域名_文件名
	// hostname := strings.TrimPrefix(parsedUrl.Hostname(), "www.")
	// uniqueFilename := fmt.Sprintf("%s_%s", hostname, filename)

	// 构建完整的文件路径
	fullPath := filepath.Join(targetDir, filename)
	progressFile := fullPath + ".progress"

	// 检查是否有未完成的下载
	progress, err := loadProgress(progressFile)
	if err != nil {
		return fmt.Errorf("加载进度失败:%v", err)
	}

	// 如果文件已存在,检查是否需要继续下载
	if fileExists(fullPath) {
		if progress == nil {
			// 文件存在但没有进度文件,创建新文件名
			for i := 1; ; i++ {
				ext := path.Ext(filename)
				basename := strings.TrimSuffix(filename, ext)
				fullPath = filepath.Join(targetDir, fmt.Sprintf("%s_%d%s", basename, i, ext))
				progressFile = fullPath + ".progress"
				if !fileExists(fullPath) {
					break
				}
			}
		} else {
			// 有进度文件,检查是否支持断点续传
			req.Header.Set("Range", fmt.Sprintf("bytes=%d-", progress.Downloaded))
		}
	}

	// 发送请求
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("下载失败:%v", err)
	}
	defer resp.Body.Close()

	// 如果需要继续递归,先提取链接
	var resources, links []string
	if depth > 0 {
		// 创建 TeeReader 同时读取和写入
		var buf bytes.Buffer
		teeReader := io.TeeReader(resp.Body, &buf)

		// 提取资源和链接
		resources, links = extractResources(teeReader, parsedUrl)

		// 使用缓冲的内容创建新的 Reader 用于保存文件
		resp.Body = io.NopCloser(&buf)
	}

	// 创建或打开文件
	var file *os.File
	if progress != nil {
		file, err = os.OpenFile(fullPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
	} else {
		file, err = os.Create(fullPath)
		progress = &DownloadProgress{
			URL:           urll,
			LocalPath:     fullPath,
			ContentLength: resp.ContentLength,
			Downloaded:    0,
			StartTime:     time.Now(),
		}
		// 立即保存初始进度
		if err := saveProgress(progress, progressFile); err != nil {
			fmt.Printf("保存初始进度失败:%v\n", err)
		}
	}
	if err != nil {
		return fmt.Errorf("创建文件失败:%s", fullPath)
	}

	defer func() {
		if file != nil {
			file.Sync() // 确保数据写入磁盘
			file.Close()
		}
	}()

	// 使用父 context 创建子context
	downloadCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// 创建一个等待组来确保信号处理完成
	var signalWg sync.WaitGroup
	signalWg.Add(1)

	// 启动信号处理 goroutine
	go func() {
		defer signalWg.Done()
		<-downloadCtx.Done()
		if progress != nil {
			file.Sync()
			if err := saveProgress(progress, progressFile); err != nil {
				fmt.Printf("保存进度失败:%v\n", err)
			}
		}
	}()

	// 创建带进度的写入器
	progressWriter := &ProgressWriter{
		Writer:   file,
		Progress: progress,
		bar: progressbar.NewOptions64(
			progress.ContentLength,
			progressbar.OptionSetDescription("下载中"),
			progressbar.OptionShowBytes(true),
			progressbar.OptionSetWidth(50),
		),
		ctx: downloadCtx,
	}

	//写入文件
	_, err = io.Copy(progressWriter, resp.Body)
	if err != nil {
		if err == context.Canceled {
			fmt.Println("\n下载已中断")
			// 确保在中断时保存最终进度
			file.Sync()

			// 尝试读取文件大小
			if fi, err := file.Stat(); err == nil {
				progress.Downloaded = fi.Size()
			}

			if err := saveProgress(progress, progressFile); err != nil {
				fmt.Printf("保存最终进度失败:%v\n", err)
			}

			// 再次确保文件已同步
			file.Sync()

			// 等待信号处理完成
			signalWg.Wait()
			fmt.Printf("已保存进度,当前下载大小:%d 字节\n", progress.Downloaded)
			return context.Canceled
		}
		// 只有在非取消的错误情况下才删除文件
		file.Close()
		if err != context.Canceled {
			os.Remove(fullPath)
			os.Remove(progressFile)
		}
		return fmt.Errorf("写入文件失败:%s", fullPath)
	}

	//只有在完全下载完成时才删除进度文件
	defer func() {
		if progress.Downloaded == progress.ContentLength {
			file.Sync()
			os.Remove(progressFile)
		}
	}()

	// 处理提取的资源和链接
	if depth > 0 {
		var wg sync.WaitGroup

		// 下载资源和链接的通道
		downloadChan := make(chan string, len(resources)+len(links))

		for i := 0; i < concurrency; i++ {
			go func() {
				for url := range downloadChan {
					select {
					case <-downloadCtx.Done():
						wg.Done()
						return
					default:
						err := downloadWebPage(downloadCtx, client, url, dir, depth-1, downloadRes, concurrency)
						if err != nil {
							fmt.Printf("下载失败 %s: %v\n", url, err)
						}
						wg.Done()
					}

				}
			}()
		}

		// 添加资源下载任务
		if downloadRes {
			for _, resURL := range resources {
				select {
				case <-downloadCtx.Done():
					return context.Canceled
				default:
					resUrlObj, err := url.Parse(resURL)
					if err == nil && resUrlObj.Host != "" {
						wg.Add(1)
						downloadChan <- resUrlObj.String()
					}
				}

			}
		}

		// 添加链接下载任务
		if downloadRes {
			for _, link := range links {
				select {
				case <-downloadCtx.Done():
					return context.Canceled
				default:
					linkURL, err := url.Parse(link)
					if err == nil && linkURL.Host != "" {
						wg.Add(1)
						downloadChan <- linkURL.String()
					}
				}

			}
		}

		// 关闭下载通道
		close(downloadChan)

		// 使用带超时的等待
		done := make(chan struct{})
		go func() {
			wg.Wait()
			close(done)
		}()

		select {
		case <-downloadCtx.Done():
			return context.Canceled
		case <-done:
			return nil
		}
	}

	return nil
}

// 判断文件是否存在
func fileExists(filepath string) bool {
	_, err := os.Stat(filepath)
	return !os.IsNotExist(err)
}

// 解析HTML并提取资源链接
func extractResources(body io.Reader, baseURL *url.URL) ([]string, []string) {
	var resources, links []string
	tokenizer := html.NewTokenizer(body)

	for {
		tokenType := tokenizer.Next()
		if tokenType == html.ErrorToken {
			break
		}

		token := tokenizer.Token()
		switch token.Data {
		case "img", "script", "link":
			// 提取资源URL
			for _, attr := range token.Attr {
				if attr.Key == "src" || attr.Key == "href" {
					resURL, err := baseURL.Parse(attr.Val)
					if err == nil {
						resources = append(resources, resURL.String())
					}
				}
			}
		case "a":
			// 提取链接URL
			for _, attr := range token.Attr {
				if attr.Key == "href" {
					linkURL, err := baseURL.Parse(attr.Val)
					if err == nil && strings.HasPrefix(linkURL.String(), baseURL.String()) {
						links = append(links, linkURL.String())
					}
				}
			}
		}
	}
	// URL 编码处理
	for i, resURL := range resources {
		if u, err := url.QueryUnescape(resURL); err == nil {
			resources[i] = u
		}
	}

	for i, link := range links {
		if u, err := url.QueryUnescape(link); err == nil {
			links[i] = u
		}
	}
	return resources, links
}

type DownloadProgress struct {
	URL           string    `json:"url"`
	LocalPath     string    `json:"local_path"`
	ContentLength int64     `json:"content_length"`
	Downloaded    int64     `json:"downloaded"`
	StartTime     time.Time `json:"start_time"`
}

// 保存下载进度
func saveProgress(progress *DownloadProgress, progressFile string) error {
	// 确保目录存在
	dir := filepath.Dir(progressFile)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return fmt.Errorf("创建进度文件目录失败:%v", err)
	}

	data, err := json.Marshal(progress)
	if err != nil {
		return err
	}

	// 直接写入进度文件,不使用临时文件
	if err := os.WriteFile(progressFile, data, 0644); err != nil {
		return err
	}

	return nil
}

// 加载下载进度
func loadProgress(progressFile string) (*DownloadProgress, error) {
	data, err := os.ReadFile(progressFile)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, nil
		}
		return nil, err
	}

	var progress DownloadProgress
	if err := json.Unmarshal(data, &progress); err != nil {
		return nil, err
	}
	return &progress, nil
}

// 进度写入器
//
//	type ProgressWriter struct {
//		Writer   io.Writer
//		Progress *DownloadProgress
//		OnWrite  func(*DownloadProgress)
//		ctx      context.Context
//	}
type ProgressWriter struct {
	Writer   io.Writer
	Progress *DownloadProgress
	bar      *progressbar.ProgressBar
	ctx      context.Context
}

func NewProgressWriter(w io.Writer, contentLength int64) *ProgressWriter {
	bar := progressbar.NewOptions64(
		contentLength,
		progressbar.OptionSetDescription("下载中"),
		progressbar.OptionShowBytes(true),
		progressbar.OptionSetWidth(50),
	)
	return &ProgressWriter{
		Writer:   w,
		Progress: &DownloadProgress{ContentLength: contentLength},
		bar:      bar,
	}
}

func (pw *ProgressWriter) Write(p []byte) (n int, err error) {
	// 检查是否需要中断
	select {
	case <-pw.ctx.Done():
		// 写入最后的数据
		n, err = pw.Writer.Write(p)
		if err == nil {
			pw.Progress.Downloaded += int64(n)
			if err = pw.bar.Set64(pw.Progress.Downloaded); err != nil {
				return n, err
			}
		}
		return n, context.Canceled
	default:
		n, err = pw.Writer.Write(p)
		if err == nil {
			pw.Progress.Downloaded += int64(n)
			if err = pw.bar.Set64(pw.Progress.Downloaded); err != nil {
				return n, err
			}
			// // 计算并显示进度
			// if pw.Progress.ContentLength > 0 {
			// 	percent := float64(pw.Progress.Downloaded) / float64(pw.Progress.ContentLength) * 100
			// 	speed := float64(pw.Progress.Downloaded) / time.Since(pw.Progress.StartTime).Seconds() / 1024 // KB/s

			// 	// 强制刷新输出
			// 	fmt.Fprintf(os.Stdout, "\r正在下载 %s... %.2f%% (%.2f KB/s)     ",
			// 		filepath.Base(pw.Progress.LocalPath),
			// 		percent,
			// 		speed)
			// 	os.Stdout.Sync()
			// } else {
			// 	// 未知文件大小,只显示已下载大小和速度
			// 	downloadedMB := float64(pw.Progress.Downloaded) / 1024 / 1024
			// 	speed := float64(pw.Progress.Downloaded) / time.Since(pw.Progress.StartTime).Seconds() / 1024 // KB/s
			// 	fmt.Fprintf(os.Stdout, "\r正在下载 %s... %.2f MB (%.2f KB/s)     ",
			// 		filepath.Base(pw.Progress.LocalPath),
			// 		downloadedMB,
			// 		speed)
			// 	os.Stdout.Sync()
			// }
			// // 如果下载完成,打印换行符
			// if pw.Progress.Downloaded == pw.Progress.ContentLength {
			// 	fmt.Fprintln(os.Stdout)
			// }
			// // 强制刷新标准输出
			// os.Stdout.Sync()
		}
	}
	return
}

// 添加一个新的函数来处理文件名
func sanitizeFileName(filename string) string {
	// 替换 Windows 不允许的字符
	invalid := []string{"<", ">", ":", "\"", "/", "\\", "|", "?", "*", "+"}
	result := filename
	for _, char := range invalid {
		result = strings.ReplaceAll(result, char, "_")
	}
	return result
}

以上代码还有一些小bug,1.ctl+c结束下载速度较慢 2.无法实时显示下载进度。

后续会解决这些bug