需求:递归拷贝目录,支持通过 context 中断,并通过 callback 回报进度;拷贝后目标目录、文件的属性需与源保持一致。
cpdir.go代码如下:
package cpdir
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"time"
"github.com/cheggaaa/pb"
)
var (
	// RefreshRate is the refresh interval CopyDir passes to its progress bar
	// via SetRefreshRate.
	RefreshRate = 10 * time.Second
	// BufferSize is the length, in bytes, of the copy buffer CopyDir
	// allocates and reuses for every file (4 MiB).
	BufferSize = 4 << 20
)
// CopyDir recursively copies the directory tree rooted at src.
//
// dstFunc maps every source path reported by filepath.Walk (src itself, each
// sub-directory and each file) to its destination path. callback, if non-nil,
// is installed as the progress bar's callback and receives the bar together
// with the progress text it is given.
//
// Destination directories are created first (kept owner-writable so their
// contents can be written). Files are then copied one by one. Finally each
// destination directory receives the permission bits and modification time
// of its source.
//
// Symlinks that resolve to regular files are copied as regular files holding
// the target's content. Any other non-regular entry (FIFO, socket, device,
// symlink to a directory) aborts the copy with an error.
//
// When ctx is cancelled the copy stops at the next check and ctx.Err() is
// returned unwrapped. Other errors wrap their cause with %w.
func CopyDir(ctx context.Context, src string, dstFunc func(string) string,
	callback func(*pb.ProgressBar, string)) error {
	dirs, files, totalSize, err := scanTree(ctx, src, dstFunc)
	if err != nil {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		return fmt.Errorf("walk path %s failed, %w", src, err)
	}
	// NOTE(review): log is not in the visible import block; presumably a
	// project logger that provides Infof — confirm.
	log.Infof("start copy %d files (%s) for dir %s", len(files), formatBytes(totalSize), src)

	// create bar
	bar := pb.New64(totalSize).SetUnits(pb.U_BYTES).SetRefreshRate(RefreshRate).SetWidth(0)
	bar.ShowSpeed = true
	if callback != nil {
		bar.Callback = func(s string) { callback(bar, s) }
	}
	bar.Start()
	defer bar.Finish()

	// BufferSize is exported and mutable. A negative value would make make()
	// panic, and an empty non-nil buffer makes io.CopyBuffer panic. In both
	// cases fall back to io's default buffer.
	var buf []byte
	if BufferSize > 0 {
		buf = make([]byte, BufferSize)
	}
	for _, filePath := range files {
		// Checked here as well: copying an empty file performs no Write, so
		// the contextWriter alone would never observe the cancellation.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		if copyErr := copyFile(ctx, filePath, dstFunc(filePath), buf, bar); copyErr != nil {
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			return copyErr
		}
	}
	// Only now, after every entry exists, can directory attributes be fixed:
	// creating files inside a directory changes its mtime, and a read-only
	// mode would have blocked the copy.
	return applyDirAttrs(dirs)
}

// dirEntry pairs a destination directory with its source directory's info,
// whose mode and modification time are applied once copying is complete.
type dirEntry struct {
	dst  string
	info os.FileInfo
}

// scanTree walks src, creates every destination directory and collects the
// files to copy together with their total size. It aborts with ctx.Err()
// once ctx is cancelled.
func scanTree(ctx context.Context, src string, dstFunc func(string) string) ([]dirEntry, []string, int64, error) {
	var (
		dirs      []dirEntry
		files     []string
		totalSize int64
	)
	err := filepath.Walk(src, func(path string, info os.FileInfo, walkErr error) error {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		if walkErr != nil {
			return fmt.Errorf("walk error at path %s, %w", path, walkErr)
		}
		if info.IsDir() {
			dstPath := dstFunc(path)
			// Keep the directory owner-writable/searchable while copying.
			// The exact source mode is applied later by applyDirAttrs via
			// Chmod, which, unlike the perm given to MkdirAll, is not masked
			// by the umask.
			if err := os.MkdirAll(dstPath, info.Mode().Perm()|0700); err != nil {
				return fmt.Errorf("mkdir failed for path %s, %w", dstPath, err)
			}
			dirs = append(dirs, dirEntry{dst: dstPath, info: info})
			return nil
		}
		if !info.Mode().IsRegular() {
			// Walk uses Lstat, so symlinks and special files arrive here.
			// copyFile follows symlinks, so the progress bar must be sized by
			// the target. Anything that is not a regular file underneath is
			// rejected: os.Open can block forever on a FIFO, and a symlinked
			// directory cannot be read as a file.
			target, statErr := os.Stat(path)
			if statErr != nil {
				return fmt.Errorf("stat failed for path %s, %w", path, statErr)
			}
			if !target.Mode().IsRegular() {
				return fmt.Errorf("unsupported file type %v at path %s", info.Mode(), path)
			}
			info = target
		}
		totalSize += info.Size()
		files = append(files, path)
		return nil
	})
	return dirs, files, totalSize, err
}

// applyDirAttrs copies the modification time and permission bits of each
// source directory onto its destination. It goes deepest first, because
// filepath.Walk lists parents before children: children are finalised
// before their parent can lose write or search permission.
func applyDirAttrs(dirs []dirEntry) error {
	for i := len(dirs) - 1; i >= 0; i-- {
		d := dirs[i]
		// FileInfo exposes no portable access time; use mtime for both.
		mtime := d.info.ModTime()
		if err := os.Chtimes(d.dst, mtime, mtime); err != nil {
			return fmt.Errorf("chtimes failed for path %s, %w", d.dst, err)
		}
		if err := os.Chmod(d.dst, d.info.Mode()); err != nil {
			return fmt.Errorf("chmod failed for path %s, %w", d.dst, err)
		}
	}
	return nil
}
// copyFile copies srcPath to dstPath through buf. Bytes written are reported
// to bar, and copying aborts with a wrapped ctx.Err() once ctx is done.
//
// On success dstPath carries srcPath's modification time and permission
// bits. On any failure the partially written dstPath is removed
// (best effort), so a truncated file never passes for a finished copy.
func copyFile(ctx context.Context, srcPath string, dstPath string, buf []byte, bar *pb.ProgressBar) (err error) {
	srcFile, err := os.Open(srcPath)
	if err != nil {
		return err
	}
	defer srcFile.Close()

	// get source mode and modification time
	srcFileInfo, err := srcFile.Stat()
	if err != nil {
		return fmt.Errorf("stat failed for file %s, %w", srcPath, err)
	}

	dstFile, err := os.Create(dstPath)
	if err != nil {
		return fmt.Errorf("create failed for file %s, %w", dstPath, err)
	}
	defer func() {
		if err != nil {
			// Best-effort cleanup. If the explicit Close below already ran,
			// this second Close just reports os.ErrClosed, which is
			// irrelevant here.
			_ = dstFile.Close()
			_ = os.Remove(dstPath)
		}
	}()

	// create proxy writer with hook (ctx, progress bar)
	writer := newContextWriter(ctx, bar.NewProxyWriter(dstFile))
	// struct{ io.Reader } hides any WriterTo method of *os.File (added in
	// newer Go releases). Otherwise io.CopyBuffer may delegate to it and
	// copy through its own internal buffer, ignoring buf.
	if _, err = io.CopyBuffer(writer, struct{ io.Reader }{srcFile}, buf); err != nil {
		return fmt.Errorf("copy failed for file %s, %w", dstPath, err)
	}
	// Close explicitly and check the result: a deferred, unchecked Close
	// would discard delayed write errors and silently lose data.
	if err = dstFile.Close(); err != nil {
		return fmt.Errorf("close failed for file %s, %w", dstPath, err)
	}
	// Copy attributes only once the data is complete. Neither call alters
	// the contents, and doing the mode last keeps a read-only mode from
	// interfering with the earlier steps.
	mtime := srcFileInfo.ModTime()
	if err = os.Chtimes(dstPath, mtime, mtime); err != nil {
		return fmt.Errorf("chtimes failed for file %s, %w", dstPath, err)
	}
	if err = os.Chmod(dstPath, srcFileInfo.Mode()); err != nil {
		return fmt.Errorf("chmod failed for file %s, %w", dstPath, err)
	}
	return nil
}
// contextWriter is an io.Writer that stops accepting data once its context
// is done. Wrapped around the destination of an io.Copy, it lets the copy be
// interrupted between chunks.
type contextWriter struct {
	ctx    context.Context
	writer io.Writer
}

var _ io.Writer = (*contextWriter)(nil)

// newContextWriter wraps writer so that every Write first consults ctx.
func newContextWriter(ctx context.Context, writer io.Writer) io.Writer {
	w := &contextWriter{}
	w.ctx = ctx
	w.writer = writer
	return w
}

// Write forwards p to the wrapped writer while ctx is still live. Once ctx
// is done it writes nothing and returns 0 together with ctx.Err().
func (c *contextWriter) Write(p []byte) (int, error) {
	if err := c.ctx.Err(); err != nil {
		return 0, err
	}
	return c.writer.Write(p)
}
// formatBytes renders a byte count with a binary (1024-based) unit. Values
// of at least 1 KB get two decimals and the largest fitting unit up to PB.
// Smaller values (including negatives) are printed as whole bytes.
func formatBytes(bytes int64) string {
	units := [...]struct {
		name string
		size int64
	}{
		{"PB", 1 << 50},
		{"TB", 1 << 40},
		{"GB", 1 << 30},
		{"MB", 1 << 20},
		{"KB", 1 << 10},
	}
	for _, u := range units {
		if bytes >= u.size {
			return fmt.Sprintf("%.2f %s", float64(bytes)/float64(u.size), u.name)
		}
	}
	return fmt.Sprintf("%.0f %s", float64(bytes), "B")
}
使用示例一:
cpdir_test.go
package cpdir
import (
"context"
"strings"
"testing"
"time"
"github.com/cheggaaa/pb"
)
// TestCopyDirLocal copies this package's own source directory into a
// temporary directory and fails on any error. It then checks that a
// cancelled context aborts the copy with context.Canceled.
func TestCopyDirLocal(t *testing.T) {
	defer func(old time.Duration) { RefreshRate = old }(RefreshRate)
	RefreshRate = time.Second * 1

	// filepath.Walk rooted at "." reports "." for the root and plain relative
	// paths ("sub/file") below it. Appending them to a destination root
	// therefore maps the whole tree. Windows also accepts "/" as a separator.
	src := "."
	into := func(root string) func(string) string {
		return func(path string) string {
			return strings.Join([]string{root, path}, "/")
		}
	}

	err := CopyDir(context.TODO(), src, into(t.TempDir()), func(bar *pb.ProgressBar, s string) {
		t.Logf("progress: %s", s)
	})
	if err != nil {
		t.Fatalf("CopyDir(%q): %v", src, err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	if err := CopyDir(ctx, src, into(t.TempDir()), nil); err != context.Canceled {
		t.Fatalf("CopyDir with cancelled context: got %v, want %v", err, context.Canceled)
	}
}
示例二:
demo.go
// handleTask copies the task's directory with cpdir.CopyDir while a helper
// goroutine forwards progress to s.updateProgress. The outcome is logged.
func (s *Syncer) handleTask(task *Task) {
	log.Infof("Task %s start...", task.Name)
	callback, stopProgressUpdate := s.startProgressUpdate(task)
	defer stopProgressUpdate()
	// Pass the source directory srcPath and the destination mapping dstFunc here.
	if err := cpdir.CopyDir(task.ctx, srcPath, dstFunc, callback); err != nil {
		// Previously discarded: a failed or cancelled copy looked exactly
		// like a successful one.
		log.Infof("Task %s failed: %v", task.Name, err)
		return
	}
	log.Infof("Task %s done", task.Name)
}
// startProgressUpdate returns a progress callback for cpdir.CopyDir and a
// stop function. Each progress string is logged and queued for a single
// goroutine that calls s.updateProgress sequentially. stop closes the queue,
// waits for that goroutine to drain it, and is safe to call more than once.
func (s *Syncer) startProgressUpdate(task *Task) (
	callback func(*pb.ProgressBar, string), stop func()) {
	progressCh := make(chan string, 10) //nolint:gomnd
	var (
		mu     sync.Mutex // guards closed and every send/close on progressCh
		closed bool
	)
	callback = func(_ *pb.ProgressBar, msg string) {
		log.Infof("task %s progress: %s", task.Name, msg)
		mu.Lock()
		defer mu.Unlock()
		// The callback may be invoked from the progress bar's own goroutine.
		// If one fires after stop() (presumably a refresh in flight when
		// CopyDir returns — confirm against pb), sending on the closed
		// channel would panic.
		if closed {
			return
		}
		// Never block the copy. If the consumer lags, discard the oldest
		// queued update rather than the newest, so the most recent progress
		// (e.g. the final one) is what gets delivered.
		select {
		case progressCh <- msg:
		default:
			select {
			case <-progressCh:
			default:
			}
			select {
			case progressCh <- msg:
			default:
			}
		}
	}
	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		for progress := range progressCh {
			// Business logic goes here.
			s.updateProgress(task, progress)
		}
	}()
	stop = func() {
		mu.Lock()
		if !closed {
			closed = true
			close(progressCh)
		}
		mu.Unlock()
		wg.Wait()
	}
	return callback, stop
}