golang拷贝目录(带中断与进度)

52 次阅读 · 阅读约 2 分钟

需求:拷贝目录,支持中途中断(取消),并通过回调(callback)报告进度;同时保持目录属性和文件属性与源一致。

cpdir.go 代码如下:

package cpdir

import (
    "context"
    "fmt"
    "io"
    "os"
    "path/filepath"
    "time"

    "github.com/cheggaaa/pb"
)

var (
	// RefreshRate is passed to the progress bar's SetRefreshRate in CopyDir,
	// controlling how often the bar redraws and invokes its callback. It is
	// read without synchronization, so change it before starting a copy.
	RefreshRate = 10 * time.Second

	// BufferSize is the size in bytes of the copy buffer that CopyDir
	// allocates once and reuses for every file it copies.
	BufferSize = 4 << 20
)

// CopyDir copies the directory tree rooted at src to the locations chosen by
// dstFunc, which maps every source path (directories and files alike) to its
// destination path. Copying stops early once ctx is done, in which case the
// bare ctx.Err() is returned. If callback is non-nil it receives the progress
// bar and its rendered text whenever the bar refreshes.
//
// Permission bits and modification times of directories and files are copied
// from the source. Destination directories are created owner-writable first,
// so read-only source directories can still be populated. Their exact mode and
// mtime are applied after all files have been copied.
//
// Other errors are wrapped with %w, so errors.Is/errors.As work on them.
func CopyDir(ctx context.Context, src string, dstFunc func(string) string,
	callback func(*pb.ProgressBar, string)) error {
	dirs, files, totalSize, err := scanTree(ctx, src, dstFunc)
	if err != nil {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		return fmt.Errorf("walk path %s failed, %w", src, err)
	}
	// NOTE(review): `log` is not imported in this snippet; presumably a
	// project logger exposing Infof — confirm.
	log.Infof("start copy %d files (%s) for dir %s", len(files), formatBytes(totalSize), src)

	bar := pb.New64(totalSize).SetUnits(pb.U_BYTES).SetRefreshRate(RefreshRate).SetWidth(0)
	bar.ShowSpeed = true
	if callback != nil {
		bar.Callback = func(s string) { callback(bar, s) }
	}
	bar.Start()
	defer bar.Finish()

	// A single buffer is shared by every file copy to avoid per-file allocations.
	buf := make([]byte, BufferSize)
	for _, filePath := range files {
		// Stop before touching (creating/truncating) the next destination.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		if copyErr := copyFile(ctx, filePath, dstFunc(filePath), buf, bar); copyErr != nil {
			// Surface cancellation as the bare context error so callers can
			// compare it directly with context.Canceled / DeadlineExceeded.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			return copyErr
		}
	}
	return restoreDirAttrs(dirs)
}

// dirAttr pairs a destination directory with the source FileInfo whose mode
// and modification time must eventually be applied to it.
type dirAttr struct {
	dstPath string
	info    os.FileInfo
}

// scanTree walks src once. It creates every destination directory and returns
// those directories in walk order, the non-directory source paths to copy,
// and the sum of their sizes in bytes.
func scanTree(ctx context.Context, src string, dstFunc func(string) string) ([]dirAttr, []string, int64, error) {
	var (
		dirs      []dirAttr
		files     []string
		totalSize int64
	)
	err := filepath.Walk(src, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return fmt.Errorf("walk error at path %s, %w", path, walkErr)
		}
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		if !info.IsDir() {
			// NOTE(review): Walk does not follow symlinks, so symlinks and
			// other non-regular entries land here too. copyFile opens them
			// with os.Open, which follows symlinks.
			totalSize += info.Size()
			files = append(files, path)
			return nil
		}
		dstPath := dstFunc(path)
		// Force owner rwx while copying so a read-only source directory does
		// not block creating its contents. Special bits such as setgid are
		// kept from the source mode. restoreDirAttrs applies the exact mode
		// afterwards.
		if mkErr := os.MkdirAll(dstPath, info.Mode()|0o700); mkErr != nil {
			return fmt.Errorf("mkdir failed for path %s, %w", dstPath, mkErr)
		}
		dirs = append(dirs, dirAttr{dstPath: dstPath, info: info})
		return nil
	})
	return dirs, files, totalSize, err
}

// restoreDirAttrs applies each source directory's mode and modification time
// to its destination. filepath.Walk lists a directory before its contents, so
// dirs is processed in reverse. That way every descendant is finalized before
// an ancestor's (possibly restrictive) mode could make it unreachable.
func restoreDirAttrs(dirs []dirAttr) error {
	for i := len(dirs) - 1; i >= 0; i-- {
		d := dirs[i]
		if err := os.Chmod(d.dstPath, d.info.Mode()); err != nil {
			return fmt.Errorf("chmod failed for path %s, %w", d.dstPath, err)
		}
		if err := os.Chtimes(d.dstPath, time.Now(), d.info.ModTime()); err != nil {
			return fmt.Errorf("chtimes failed for path %s, %w", d.dstPath, err)
		}
	}
	return nil
}

// copyFile copies srcPath to dstPath through buf. It reports written bytes to
// bar and aborts once ctx is done. The destination is created or truncated,
// and it gets the source's mode and modification time. Errors other than the
// initial open failure are wrapped with %w and name the offending file.
func copyFile(ctx context.Context, srcPath string, dstPath string, buf []byte, bar *pb.ProgressBar) error {
	// Don't create/truncate the destination if we're already cancelled.
	if err := ctx.Err(); err != nil {
		return err
	}

	srcFile, err := os.Open(srcPath)
	if err != nil {
		return err
	}
	defer srcFile.Close()

	// The source FileInfo supplies the mode and mtime to replicate.
	srcFileInfo, err := srcFile.Stat()
	if err != nil {
		return fmt.Errorf("stat failed for file %s, %w", srcPath, err)
	}

	dstFile, err := os.Create(dstPath)
	if err != nil {
		return fmt.Errorf("create failed for file %s, %w", dstPath, err)
	}
	// Safety net for early returns. On success the file is closed explicitly
	// below, so a failing Close (lost data) is reported instead of ignored.
	// The second Close then merely returns os.ErrClosed, which is discarded.
	defer dstFile.Close()

	if err = os.Chmod(dstPath, srcFileInfo.Mode()); err != nil {
		return fmt.Errorf("chmod failed for file %s, %w", dstPath, err)
	}

	// Every write goes through the ctx check and then the progress bar.
	writer := newContextWriter(ctx, bar.NewProxyWriter(dstFile))
	// Hide *os.File's WriteTo method (Go 1.22+). Otherwise io.CopyBuffer
	// prefers it and silently ignores buf, and with it BufferSize.
	reader := struct{ io.Reader }{srcFile}
	if _, err = io.CopyBuffer(writer, reader, buf); err != nil {
		return fmt.Errorf("copy failed for file %s, %w", dstPath, err)
	}

	if err = dstFile.Close(); err != nil {
		return fmt.Errorf("close failed for file %s, %w", dstPath, err)
	}
	// Set the timestamp only after closing, so no later write can bump mtime.
	if err = os.Chtimes(dstPath, time.Now(), srcFileInfo.ModTime()); err != nil {
		return fmt.Errorf("chtimes failed for file %s, %w", dstPath, err)
	}
	return nil
}


// contextWriter forwards writes to an underlying io.Writer until its context
// is done. From then on every Write fails with the context's error and
// nothing is forwarded.
type contextWriter struct {
	ctx    context.Context
	writer io.Writer
}

// newContextWriter wraps writer so that writes stop once ctx is done.
func newContextWriter(ctx context.Context, writer io.Writer) io.Writer {
	cw := contextWriter{ctx: ctx, writer: writer}
	return &cw
}

// Write checks the context before each forwarded write. A cancelled context
// yields (0, ctx.Err()) without touching the underlying writer.
func (cw *contextWriter) Write(p []byte) (int, error) {
	if err := cw.ctx.Err(); err != nil {
		return 0, err
	}
	return cw.writer.Write(p)
}


// formatBytes renders a byte count in binary (1024-based) units. Values of at
// least 1 KB use the largest fitting unit up to PB, with two decimals, e.g.
// "1.50 KB". Smaller values, including negative ones, are printed as a whole
// number of bytes, e.g. "512 B".
func formatBytes(bytes int64) string {
	// Ordered largest first, so the first match is the biggest fitting unit.
	scales := [...]struct {
		name string
		size int64
	}{
		{"PB", 1 << 50},
		{"TB", 1 << 40},
		{"GB", 1 << 30},
		{"MB", 1 << 20},
		{"KB", 1 << 10},
	}
	for _, sc := range scales {
		if bytes >= sc.size {
			return fmt.Sprintf("%.2f %s", float64(bytes)/float64(sc.size), sc.name)
		}
	}
	return fmt.Sprintf("%.0f %s", float64(bytes), "B")
}

使用示例一:

cpdir_test.go

package cpdir

import (
	"context"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"github.com/cheggaaa/pb"
)

// TestCopyDirLocal copies a small temporary tree with CopyDir and checks that
// every file arrives with the same content. It uses t.TempDir, so it runs on
// any OS. Destination paths are mapped by prefix, so file names that merely
// contain the source directory's text are not rewritten.
func TestCopyDirLocal(t *testing.T) {
	src := t.TempDir()
	dst := filepath.Join(t.TempDir(), "copy")

	// Keys are slash-separated relative paths; values are file contents.
	want := map[string]string{
		"a.txt":     "hello",
		"sub/b.txt": "world",
	}
	if err := os.MkdirAll(filepath.Join(src, "sub"), 0o755); err != nil {
		t.Fatal(err)
	}
	for rel, content := range want {
		if err := os.WriteFile(filepath.Join(src, filepath.FromSlash(rel)), []byte(content), 0o644); err != nil {
			t.Fatal(err)
		}
	}

	// Refresh more often for the test, and restore the package default afterwards.
	oldRate := RefreshRate
	RefreshRate = time.Second * 1
	defer func() { RefreshRate = oldRate }()

	err := CopyDir(context.TODO(), src, func(path string) string {
		return filepath.Join(dst, strings.TrimPrefix(path, src))
	}, func(bar *pb.ProgressBar, s string) {
		t.Logf("progress: %s", s)
	})
	if err != nil {
		t.Fatalf("CopyDir: %v", err)
	}

	for rel, content := range want {
		got, err := os.ReadFile(filepath.Join(dst, filepath.FromSlash(rel)))
		if err != nil {
			t.Fatalf("read copied %s: %v", rel, err)
		}
		if string(got) != content {
			t.Errorf("%s: got %q, want %q", rel, got, content)
		}
	}
}

示例二:

demo.go

// handleTask runs one directory-copy task, forwarding progress through
// startProgressUpdate. It logs whether the copy succeeded instead of silently
// discarding CopyDir's error.
func (s *Syncer) handleTask(task *Task) {
	log.Infof("Task %s start...", task.Name)

	callback, stopProgressUpdate := s.startProgressUpdate(task)
	defer stopProgressUpdate()

	// Pass the source directory srcPath and the destination-mapping function
	// dstFunc here (placeholders: neither is defined in this example).
	if err := cpdir.CopyDir(task.ctx, srcPath, dstFunc, callback); err != nil {
		// NOTE(review): assumes the project logger also provides Errorf.
		log.Errorf("Task %s failed: %v", task.Name, err)
		return
	}
	log.Infof("Task %s done", task.Name)
}


// startProgressUpdate returns a progress callback suitable for cpdir.CopyDir,
// plus a stop function.
//
// Each progress string is logged, then queued without blocking. Updates are
// dropped when 10 are already pending. A background goroutine drains the
// queue into s.updateProgress. stop closes the queue and waits for that
// goroutine to finish. It is safe to call more than once. Callbacks that
// arrive after stop are logged but otherwise ignored, rather than panicking
// with a send on a closed channel.
func (s *Syncer) startProgressUpdate(task *Task) (
	callback func(*pb.ProgressBar, string), stop func()) {
	progressCh := make(chan string, 10) //nolint:gomnd

	// mu guards closed and every send on progressCh. A callback arriving late
	// (e.g. from a progress-bar refresh not synchronized with stop — confirm
	// against pb) therefore never sends on a closed channel.
	var (
		mu     sync.Mutex
		closed bool
	)
	callback = func(_ *pb.ProgressBar, progress string) {
		log.Infof("task %s progress: %s", task.Name, progress)
		mu.Lock()
		defer mu.Unlock()
		if closed {
			return
		}
		select {
		case progressCh <- progress:
		default: // queue full: drop this update rather than stall the copy
		}
	}

	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		for progress := range progressCh {
			// Business-specific handling goes here.
			s.updateProgress(task, progress)
		}
	}()

	stop = func() {
		mu.Lock()
		if !closed {
			closed = true
			close(progressCh)
		}
		mu.Unlock()
		wg.Wait()
	}
	return callback, stop
}