go实现并发文件哈希计算小工具

242 阅读4分钟

基础核心功能

  1. 支持多种哈希算法

    • 包括 MD5SHA1SHA256SHA512 等
  2. 文件/目录输入

    • 支持单个文件路径或目录路径作为输入
    • 支持递归遍历子目录
  3. 并发处理

    • 使用 goroutine + channel 实现并行计算
    • 控制最大并发数防止资源耗尽
  4. 结果输出

    • 默认输出格式:[文件路径] [哈希值] [算法名称]
    • 支持简洁模式(仅哈希值)和详细模式(含错误信息)
  5. 错误处理

    • 文件不存在/权限问题的明确报错
    • 中断信号捕获(Ctrl+C)

扩展功能

  1. 哈希验证

    • 支持通过校验文件(如 .md5)进行自动验证
    • 支持通过命令行参数指定预期哈希值
  2. 并行控制

    • 限制同时打开的文件描述符数量
    • 大文件分块读取避免内存溢出
  3. 输出格式

    • 支持将结果写入指定文件

图示如下:

image.png

image.png

image.png

package main

import (
	"crypto/md5"
	"crypto/sha1"
	"crypto/sha256"
	"crypto/sha512"
	"encoding/hex"
	"fmt"
	"hash"
	"io"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"github.com/spf13/cobra"
)

// 选项结构体
type hashOptions struct {
	algorithm  string
	path       string
	output     string
	comparable string
	recursive  string
}

// 创建根命令
func newRootCmd() *cobra.Command {
	return &cobra.Command{
		Use:   "hashapp",
		Short: "filehash is a tool to calculate the hash of a file",
		Long:  `filehash is a tool to calculate the hash of a file`,
	}
}

// 创建哈希子命令
func newHashCmd(opts *hashOptions) *cobra.Command {
	cmd := &cobra.Command{
		Use:          "hash",
		Short:        "hash 文件或目录中的文件",
		SilenceUsage: true,
		PreRunE: func(cmd *cobra.Command, args []string) error {
			fileFlag := cmd.Flags().Lookup("path")
			if !fileFlag.Changed {
				cmd.Help()
				return fmt.Errorf("错误: 必须提供 -p/--path 参数")
			}
			return nil
		},
		RunE: func(cmd *cobra.Command, args []string) error {
			fmt.Printf("处理文件: %s\n使用算法: %s\n输出到: %s\n递归模式: %v\n",
				opts.path, opts.algorithm, opts.output, opts.recursive)
			if opts.output == "" {
				_, err := hashPath(opts.path, opts.algorithm, opts.recursive)
				if err != nil {
					return fmt.Errorf("处理文件失败: %v", err)
				}
			} else {
				results, err := hashPath(opts.path, opts.algorithm, opts.recursive)
				if err != nil {
					return fmt.Errorf("处理文件失败: %v", err)
				}
				if err := writeResults(results, opts.output); err != nil {
					return fmt.Errorf("写入文件失败: %v", err)
				}
			}
			return nil
		},
	}

	// 设置标志
	cmd.Flags().StringVarP(&opts.algorithm, "algorithm", "a", "md5", "支持的哈希算法:md5默认,sha1,sha256,sha512")
	cmd.Flags().StringVarP(&opts.path, "path", "p", "", "要计算哈希的文件")
	cmd.Flags().StringVarP(&opts.output, "output", "o", "", "输出文件路径")
	cmd.Flags().StringVarP(&opts.recursive, "recursive", "r", "false", "递归模式,默认false")
	// cmd.Flags().StringVarP(&opts.comparable, "comparable", "c", "", "可比较的哈希值")

	// 设置错误处理
	cmd.SilenceErrors = true

	return cmd
}

// 创建compare子命令
func newCompareCmd(opts *hashOptions) *cobra.Command {
	cmd := &cobra.Command{
		Use:          "compare",
		Short:        "比较文件哈希值",
		SilenceUsage: true,
		PreRunE: func(cmd *cobra.Command, args []string) error {
			if opts.comparable == "" {
				cmd.Help()
				return fmt.Errorf("错误: 必须提供 -c/--comparable 参数")
			}
			if opts.path == "" {
				cmd.Help()
				return fmt.Errorf("错误: 必须提供 -p/--path 参数")
			}
			return nil
		},
		RunE: func(cmd *cobra.Command, args []string) error {
			fmt.Printf("比较哈希值: %s\n", opts.comparable)
			hash, err := hashFile(opts.path, opts.algorithm)
			if err != nil {
				return fmt.Errorf("计算哈希失败: %v", err)
			}
			result := compareHash(hash, opts.comparable)
			if result {
				fmt.Println("哈希值匹配")
			} else {
				fmt.Println("哈希值不匹配")
			}
			return nil
		},
	}

	// 设置标志
	cmd.Flags().StringVarP(&opts.path, "path", "p", "", "要计算哈希的文件")
	cmd.Flags().StringVarP(&opts.comparable, "comparable", "c", "", "可比较的哈希值")
	cmd.Flags().StringVarP(&opts.algorithm, "algorithm", "a", "md5", "支持的哈希算法:md5默认,sha1,sha256,sha512")
	// 设置错误处理
	cmd.SilenceErrors = true

	return cmd
}

func main() {
	// 创建选项实例
	opts := &hashOptions{
		algorithm: "md5", // 默认值
	}

	// 创建根命令
	rootCmd := newRootCmd()

	// 添加哈希子命令
	rootCmd.AddCommand(newHashCmd(opts))
	rootCmd.AddCommand(newCompareCmd(opts))

	// 执行命令
	if err := rootCmd.Execute(); err != nil {
		fmt.Println("Error:", err)
		os.Exit(1)
	}
}

// 单文件hash计算
func hashFile(file string, algorithm string) (string, error) {
	// 打开文件
	f, err := os.Open(file)
	if err != nil {
		return "", fmt.Errorf("打开文件失败: %v", err)
	}
	defer f.Close() // 确保在函数返回时关闭文件

	// 选择哈希算法
	var h hash.Hash
	switch strings.ToLower(algorithm) {
	case "md5":
		h = md5.New()
	case "sha1":
		h = sha1.New()
	case "sha256":
		h = sha256.New()
	case "sha512":
		h = sha512.New()
	default:
		return "", fmt.Errorf("不支持的哈希算法: %s", algorithm)
	}

	// 计算哈希
	if _, err := io.Copy(h, f); err != nil {
		return "", fmt.Errorf("计算哈希失败: %v", err)
	}

	// 获取哈希值并转换为十六进制字符串
	hashValue := hex.EncodeToString(h.Sum(nil))
	return hashValue, nil
}

// 定义结果结构体
type hashResult struct {
	path string
	hash string
	err  error
}

// 递归文件遍历(添加并发控制)
func algorithmFiles(dir string, algorithm string, recursive string) ([]hashResult, error) {
	var results []hashResult
	// 创建并发控制
	workers := 5                        // 并发工作协程数
	resultChan := make(chan hashResult) // 结果通道
	done := make(chan struct{})         // 完成信号通道
	var wg sync.WaitGroup               // 等待组

	// 启动结果收集协程
	go func() {
		for result := range resultChan {
			if result.err != nil {
				fmt.Printf("警告: 处理文件 %s 失败: %v\n", result.path, result.err)
				continue
			}
			fmt.Printf("%s: %s\n", result.path, result.hash)
			results = append(results, result)
		}
		close(done)
	}()

	// 创建工作协程池
	jobs := make(chan string, workers)
	for i := 0; i < workers; i++ {
		go func() {
			for path := range jobs {
				hash, err := hashFile(path, algorithm)
				resultChan <- hashResult{path: path, hash: hash, err: err}
				wg.Done()
			}
		}()
	}

	// 递归遍历文件
	var walk func(string) error
	walk = func(dir string) error {
		files, err := os.ReadDir(dir)
		if err != nil {
			return fmt.Errorf("读取目录失败: %v", err)
		}

		for _, file := range files {
			fullPath := filepath.Join(dir, file.Name())
			fileInfo, err := file.Info()
			if err != nil {
				fmt.Printf("警告: 无法获取文件信息 %s: %v\n", fullPath, err)
				continue
			}

			if fileInfo.IsDir() {
				if recursive == "true" {
					if err := walk(fullPath); err != nil {
						fmt.Printf("警告: 处理目录 %s 失败: %v\n", fullPath, err)
					}
				}
				continue
			}

			wg.Add(1)
			jobs <- fullPath
		}
		return nil
	}

	// 开始遍历
	if err := walk(dir); err != nil {
		return results, err
	}

	// 等待所有任务完成
	wg.Wait()
	close(jobs)
	close(resultChan)
	<-done

	return results, nil
}

// 根据path是否为文件或目录进行不同处理
func hashPath(path string, algorithm string, recursive string) ([]hashResult, error) {
	fileInfo, err := os.Stat(path)
	if err != nil {
		return nil, fmt.Errorf("获取文件信息失败: %v", err)
	}
	if fileInfo.IsDir() {
		return algorithmFiles(path, algorithm, recursive)
	}
	hash, err := hashFile(path, algorithm)
	if err != nil {
		return nil, fmt.Errorf("计算哈希失败: %v", err)
	} else {
		fmt.Printf("%s: %s\n", path, hash)
		return []hashResult{{path: path, hash: hash, err: nil}}, nil
	}

}

// 比较hash
func compareHash(hash1, hash2 string) bool {
	return hash1 == hash2
}

// 将结果写入文件
func writeResults(results []hashResult, output string) error {
	if _, err := os.Stat(output); err == nil {
		// 文件存在,询问用户是否覆盖
		return fmt.Errorf("文件 %s 已存在,请指定其他文件名或删除现有文件", output)
	}
	f, err := os.Create(output)
	if err != nil {
		return fmt.Errorf("创建输出文件失败: %v", err)
	}
	defer f.Close()
	for _, result := range results {
		_, err := f.WriteString(fmt.Sprintf("%s: %s\n", result.path, result.hash))
		if err != nil {
			return fmt.Errorf("写入文件失败: %v", err)
		}
	}
	return nil
}