网页批量下载器小工具功能清单
核心功能
-
多URL批量下载
- 支持输入多个URL列表(文件或命令行参数)。
-
并发下载
- 通过协程(goroutine)实现并发下载任务。
- 可配置最大并发数(避免服务器封禁)。
-
断点续传
- 支持下载中断后恢复,记录已下载进度。
-
错误处理与重试
- 自动重试失败的下载任务(可配置重试次数)。
-
日志与进度显示
- 实时显示下载进度(单个任务和总进度)。
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/signal"
"path"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
"github.com/schollz/progressbar/v3"
"github.com/spf13/cobra"
"golang.org/x/net/html"
)
type Config struct {
Url []string
Dir string
Concurrency int
Timeout int
Retry int
Depth int // 递归深度
DownloadRes bool // 是否下载资源
}
var downloadedURLs = sync.Map{}
func main() {
cmd := newCmdWebpagedown()
if err := cmd.Execute(); err != nil {
fmt.Println(err)
}
}
func newCmdWebpagedown() *cobra.Command {
var cfg Config
//获取当前工作目录
currentDir, err := os.Getwd()
if err != nil {
currentDir = "."
}
cmd := &cobra.Command{
Use: "webpagedown",
Short: "Download web pages",
PreRunE: func(cmd *cobra.Command, args []string) error {
if len(cfg.Url) == 0 {
return fmt.Errorf("至少需要一个URL")
}
if cfg.Dir == "" {
cfg.Dir = currentDir
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("URL:", cfg.Url)
fmt.Println("Dir:", cfg.Dir)
fmt.Println("Concurrency:", cfg.Concurrency)
fmt.Println("Timeout:", cfg.Timeout)
fmt.Println("Retry:", cfg.Retry)
downloadWebPages(cfg.Url, cfg.Dir, cfg.Concurrency, cfg.Timeout, cfg.Retry, cfg.Depth, cfg.DownloadRes)
return nil
},
}
cmd.Flags().StringSliceVarP(&cfg.Url, "url", "u", []string{}, `被下载的网页URL(可以多次使用或用","间隔多个URL)`)
cmd.Flags().StringVarP(&cfg.Dir, "dir", "d", "", "下载的网页保存目录,默认为当前目录")
cmd.Flags().IntVarP(&cfg.Concurrency, "concurrency", "c", 1, "并发下载数")
cmd.Flags().IntVarP(&cfg.Timeout, "timeout", "t", 30, "超时时间")
cmd.Flags().IntVarP(&cfg.Retry, "retry", "r", 3, "重试次数")
cmd.Flags().IntVarP(&cfg.Depth, "depth", "p", 1, "递归下载深度")
cmd.Flags().BoolVarP(&cfg.DownloadRes, "resources", "s", false, "是否下载页面资源(图片等)")
return cmd
}
// 实现批量下载网页的功能
func downloadWebPages(urls []string, dir string, concurrency int, timeout int, retry int, depth int, downloadRes bool) error {
// 创建一个带取消的上下文
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// 添加信号处理
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
defer signal.Stop(sigChan)
// 监听中断信号
go func() {
<-sigChan
fmt.Println("\n接收到中断信号,正在保存下载进度...")
cancel() //立即触发取消操作
}()
//判断dir是否存在,不存在则创建
if _, err := os.Stat(dir); os.IsNotExist(err) {
os.MkdirAll(dir, os.ModePerm)
}
//创建一个http client
client := &http.Client{
Timeout: time.Duration(timeout) * time.Second,
}
//创建工作池
var wg sync.WaitGroup
jobs := make(chan string, concurrency)
results := make(chan error, concurrency)
//启动goroutine,从jobs中获取url,下载网页
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for url := range jobs {
select {
case <-ctx.Done():
return
default:
err := downloadWebPageWithRetry(ctx, client, url, dir, retry, depth, downloadRes, concurrency)
results <- err
}
}
}()
}
//启动goroutine,从results中获取下载结果
done := make(chan struct{})
go func() {
for result := range results {
if result != nil && result != context.Canceled {
fmt.Println(result)
}
}
close(done)
}()
//将url放入jobs中
urlLoop:
for _, url := range urls {
select {
case <-ctx.Done():
break urlLoop
case jobs <- url:
}
}
//关闭jobs
close(jobs)
//等待所有goroutine结束
wg.Wait()
//关闭results
close(results)
//等待打印结果的goroutine结束
<-done
return nil
}
// 实现单个网页下载重试的功能
func downloadWebPageWithRetry(ctx context.Context, client *http.Client, url string, dir string, retry int, depth int, downloadRes bool, concurrency int) error {
var lastErr error
for i := 0; i < retry; i++ {
select {
case <-ctx.Done():
return context.Canceled
default:
err := downloadWebPage(ctx, client, url, dir, depth, downloadRes, concurrency)
if err == nil || err == context.Canceled {
return err
}
lastErr = err
}
}
return fmt.Errorf("下载网页失败(重试%d次):%s, 错误:%v", retry, url, lastErr)
}
// 实现单个网页下载的功能
func downloadWebPage(ctx context.Context, client *http.Client, urll string, dir string, depth int, downloadRes bool, concurrency int) error {
// 检查并添加协议前缀
if !strings.HasPrefix(strings.ToLower(urll), "http://") &&
!strings.HasPrefix(strings.ToLower(urll), "https://") {
urll = "http://" + urll
}
// 检查是否已下载
if _, exists := downloadedURLs.Load(urll); exists {
return nil
}
downloadedURLs.Store(urll, true)
// 创建请求
req, err := http.NewRequest("GET", urll, nil)
if err != nil {
return fmt.Errorf("创建请求失败:%v", err)
}
// 设置请求头
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36")
req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8")
req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8")
req.Header.Set("Connection", "keep-alive")
//解析URL
parsedUrl, err := url.Parse(urll)
if err != nil {
return fmt.Errorf("解析URL失败:%s", urll)
}
// 根据URL路径创建目录结构
urlPath := parsedUrl.Path
if urlPath == "" || urlPath == "/" {
urlPath = "/index.html"
}
// 构建目标目录路径:基础目录 + 域名 + URL路径的目录部分
hostname := strings.TrimPrefix(parsedUrl.Hostname(), "www.")
// 安全处理目录路径
urlPath = sanitizeFileName(urlPath)
targetDir := filepath.Join(dir, sanitizeFileName(hostname), filepath.Dir(urlPath))
// 创建目录
if err := os.MkdirAll(targetDir, os.ModePerm); err != nil {
return fmt.Errorf("创建目录失败:%s", targetDir)
}
// 获取文件名
filename := filepath.Base(parsedUrl.Path)
if filename == "" || filename == "." || filename == "/" || filename == "\\" {
filename = "index.html"
}
// 安全处理文件名
filename = sanitizeFileName(filename)
// 构建唯一文件名:域名_文件名
// hostname := strings.TrimPrefix(parsedUrl.Hostname(), "www.")
// uniqueFilename := fmt.Sprintf("%s_%s", hostname, filename)
// 构建完整的文件路径
fullPath := filepath.Join(targetDir, filename)
progressFile := fullPath + ".progress"
// 检查是否有未完成的下载
progress, err := loadProgress(progressFile)
if err != nil {
return fmt.Errorf("加载进度失败:%v", err)
}
// 如果文件已存在,检查是否需要继续下载
if fileExists(fullPath) {
if progress == nil {
// 文件存在但没有进度文件,创建新文件名
for i := 1; ; i++ {
ext := path.Ext(filename)
basename := strings.TrimSuffix(filename, ext)
fullPath = filepath.Join(targetDir, fmt.Sprintf("%s_%d%s", basename, i, ext))
progressFile = fullPath + ".progress"
if !fileExists(fullPath) {
break
}
}
} else {
// 有进度文件,检查是否支持断点续传
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", progress.Downloaded))
}
}
// 发送请求
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("下载失败:%v", err)
}
defer resp.Body.Close()
// 如果需要继续递归,先提取链接
var resources, links []string
if depth > 0 {
// 创建 TeeReader 同时读取和写入
var buf bytes.Buffer
teeReader := io.TeeReader(resp.Body, &buf)
// 提取资源和链接
resources, links = extractResources(teeReader, parsedUrl)
// 使用缓冲的内容创建新的 Reader 用于保存文件
resp.Body = io.NopCloser(&buf)
}
// 创建或打开文件
var file *os.File
if progress != nil {
file, err = os.OpenFile(fullPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
} else {
file, err = os.Create(fullPath)
progress = &DownloadProgress{
URL: urll,
LocalPath: fullPath,
ContentLength: resp.ContentLength,
Downloaded: 0,
StartTime: time.Now(),
}
// 立即保存初始进度
if err := saveProgress(progress, progressFile); err != nil {
fmt.Printf("保存初始进度失败:%v\n", err)
}
}
if err != nil {
return fmt.Errorf("创建文件失败:%s", fullPath)
}
defer func() {
if file != nil {
file.Sync() // 确保数据写入磁盘
file.Close()
}
}()
// 使用父 context 创建子context
downloadCtx, cancel := context.WithCancel(ctx)
defer cancel()
// 创建一个等待组来确保信号处理完成
var signalWg sync.WaitGroup
signalWg.Add(1)
// 启动信号处理 goroutine
go func() {
defer signalWg.Done()
<-downloadCtx.Done()
if progress != nil {
file.Sync()
if err := saveProgress(progress, progressFile); err != nil {
fmt.Printf("保存进度失败:%v\n", err)
}
}
}()
// 创建带进度的写入器
progressWriter := &ProgressWriter{
Writer: file,
Progress: progress,
bar: progressbar.NewOptions64(
progress.ContentLength,
progressbar.OptionSetDescription("下载中"),
progressbar.OptionShowBytes(true),
progressbar.OptionSetWidth(50),
),
ctx: downloadCtx,
}
//写入文件
_, err = io.Copy(progressWriter, resp.Body)
if err != nil {
if err == context.Canceled {
fmt.Println("\n下载已中断")
// 确保在中断时保存最终进度
file.Sync()
// 尝试读取文件大小
if fi, err := file.Stat(); err == nil {
progress.Downloaded = fi.Size()
}
if err := saveProgress(progress, progressFile); err != nil {
fmt.Printf("保存最终进度失败:%v\n", err)
}
// 再次确保文件已同步
file.Sync()
// 等待信号处理完成
signalWg.Wait()
fmt.Printf("已保存进度,当前下载大小:%d 字节\n", progress.Downloaded)
return context.Canceled
}
// 只有在非取消的错误情况下才删除文件
file.Close()
if err != context.Canceled {
os.Remove(fullPath)
os.Remove(progressFile)
}
return fmt.Errorf("写入文件失败:%s", fullPath)
}
//只有在完全下载完成时才删除进度文件
defer func() {
if progress.Downloaded == progress.ContentLength {
file.Sync()
os.Remove(progressFile)
}
}()
// 处理提取的资源和链接
if depth > 0 {
var wg sync.WaitGroup
// 下载资源和链接的通道
downloadChan := make(chan string, len(resources)+len(links))
for i := 0; i < concurrency; i++ {
go func() {
for url := range downloadChan {
select {
case <-downloadCtx.Done():
wg.Done()
return
default:
err := downloadWebPage(downloadCtx, client, url, dir, depth-1, downloadRes, concurrency)
if err != nil {
fmt.Printf("下载失败 %s: %v\n", url, err)
}
wg.Done()
}
}
}()
}
// 添加资源下载任务
if downloadRes {
for _, resURL := range resources {
select {
case <-downloadCtx.Done():
return context.Canceled
default:
resUrlObj, err := url.Parse(resURL)
if err == nil && resUrlObj.Host != "" {
wg.Add(1)
downloadChan <- resUrlObj.String()
}
}
}
}
// 添加链接下载任务
if downloadRes {
for _, link := range links {
select {
case <-downloadCtx.Done():
return context.Canceled
default:
linkURL, err := url.Parse(link)
if err == nil && linkURL.Host != "" {
wg.Add(1)
downloadChan <- linkURL.String()
}
}
}
}
// 关闭下载通道
close(downloadChan)
// 使用带超时的等待
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-downloadCtx.Done():
return context.Canceled
case <-done:
return nil
}
}
return nil
}
// 判断文件是否存在
func fileExists(filepath string) bool {
_, err := os.Stat(filepath)
return !os.IsNotExist(err)
}
// 解析HTML并提取资源链接
func extractResources(body io.Reader, baseURL *url.URL) ([]string, []string) {
var resources, links []string
tokenizer := html.NewTokenizer(body)
for {
tokenType := tokenizer.Next()
if tokenType == html.ErrorToken {
break
}
token := tokenizer.Token()
switch token.Data {
case "img", "script", "link":
// 提取资源URL
for _, attr := range token.Attr {
if attr.Key == "src" || attr.Key == "href" {
resURL, err := baseURL.Parse(attr.Val)
if err == nil {
resources = append(resources, resURL.String())
}
}
}
case "a":
// 提取链接URL
for _, attr := range token.Attr {
if attr.Key == "href" {
linkURL, err := baseURL.Parse(attr.Val)
if err == nil && strings.HasPrefix(linkURL.String(), baseURL.String()) {
links = append(links, linkURL.String())
}
}
}
}
}
// URL 编码处理
for i, resURL := range resources {
if u, err := url.QueryUnescape(resURL); err == nil {
resources[i] = u
}
}
for i, link := range links {
if u, err := url.QueryUnescape(link); err == nil {
links[i] = u
}
}
return resources, links
}
type DownloadProgress struct {
URL string `json:"url"`
LocalPath string `json:"local_path"`
ContentLength int64 `json:"content_length"`
Downloaded int64 `json:"downloaded"`
StartTime time.Time `json:"start_time"`
}
// 保存下载进度
func saveProgress(progress *DownloadProgress, progressFile string) error {
// 确保目录存在
dir := filepath.Dir(progressFile)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("创建进度文件目录失败:%v", err)
}
data, err := json.Marshal(progress)
if err != nil {
return err
}
// 直接写入进度文件,不使用临时文件
if err := os.WriteFile(progressFile, data, 0644); err != nil {
return err
}
return nil
}
// 加载下载进度
func loadProgress(progressFile string) (*DownloadProgress, error) {
data, err := os.ReadFile(progressFile)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
var progress DownloadProgress
if err := json.Unmarshal(data, &progress); err != nil {
return nil, err
}
return &progress, nil
}
// 进度写入器
//
// type ProgressWriter struct {
// Writer io.Writer
// Progress *DownloadProgress
// OnWrite func(*DownloadProgress)
// ctx context.Context
// }
type ProgressWriter struct {
Writer io.Writer
Progress *DownloadProgress
bar *progressbar.ProgressBar
ctx context.Context
}
func NewProgressWriter(w io.Writer, contentLength int64) *ProgressWriter {
bar := progressbar.NewOptions64(
contentLength,
progressbar.OptionSetDescription("下载中"),
progressbar.OptionShowBytes(true),
progressbar.OptionSetWidth(50),
)
return &ProgressWriter{
Writer: w,
Progress: &DownloadProgress{ContentLength: contentLength},
bar: bar,
}
}
func (pw *ProgressWriter) Write(p []byte) (n int, err error) {
// 检查是否需要中断
select {
case <-pw.ctx.Done():
// 写入最后的数据
n, err = pw.Writer.Write(p)
if err == nil {
pw.Progress.Downloaded += int64(n)
if err = pw.bar.Set64(pw.Progress.Downloaded); err != nil {
return n, err
}
}
return n, context.Canceled
default:
n, err = pw.Writer.Write(p)
if err == nil {
pw.Progress.Downloaded += int64(n)
if err = pw.bar.Set64(pw.Progress.Downloaded); err != nil {
return n, err
}
// // 计算并显示进度
// if pw.Progress.ContentLength > 0 {
// percent := float64(pw.Progress.Downloaded) / float64(pw.Progress.ContentLength) * 100
// speed := float64(pw.Progress.Downloaded) / time.Since(pw.Progress.StartTime).Seconds() / 1024 // KB/s
// // 强制刷新输出
// fmt.Fprintf(os.Stdout, "\r正在下载 %s... %.2f%% (%.2f KB/s) ",
// filepath.Base(pw.Progress.LocalPath),
// percent,
// speed)
// os.Stdout.Sync()
// } else {
// // 未知文件大小,只显示已下载大小和速度
// downloadedMB := float64(pw.Progress.Downloaded) / 1024 / 1024
// speed := float64(pw.Progress.Downloaded) / time.Since(pw.Progress.StartTime).Seconds() / 1024 // KB/s
// fmt.Fprintf(os.Stdout, "\r正在下载 %s... %.2f MB (%.2f KB/s) ",
// filepath.Base(pw.Progress.LocalPath),
// downloadedMB,
// speed)
// os.Stdout.Sync()
// }
// // 如果下载完成,打印换行符
// if pw.Progress.Downloaded == pw.Progress.ContentLength {
// fmt.Fprintln(os.Stdout)
// }
// // 强制刷新标准输出
// os.Stdout.Sync()
}
}
return
}
// 添加一个新的函数来处理文件名
func sanitizeFileName(filename string) string {
// 替换 Windows 不允许的字符
invalid := []string{"<", ">", ":", "\"", "/", "\\", "|", "?", "*", "+"}
result := filename
for _, char := range invalid {
result = strings.ReplaceAll(result, char, "_")
}
return result
}
以上代码还有一些小bug,1.ctl+c结束下载速度较慢 2.无法实时显示下载进度。
后续会解决这些bug