Go实现多线程分片下载文件

889 阅读18分钟

我们在下载大文件时,通常会使用多线程下载的方式来加快下载速度。例如常用的多线程下载工具(Gopeed、Aria2、XDM等等),都是通过多线程下载技术充分利用了网络带宽,以提高下载速度。

那么多线程下载是怎么实现的呢?多个线程发送网络请求,是怎么做到同时下载一个文件呢?事实上,借助HTTP协议中的一些机制就可以实现了!

今天我们就通过使用Go语言为例,从了解HTTP请求相关的一些机制开始,实现一个多线程下载的示例。

1,多线程下载原理

事实上,多线程下载的原理很简单,主要的步骤如下:

  • 获取待下载文件大小
  • 创建一个大小与待下载文件大小相同的空白文件
  • 每个线程下载文件的一部分,并同时写入空白文件的不同位置

实现这些步骤,就涉及到HTTP协议的下列相关机制。

当然,一个线程下载文件的一部分并单独保存为文件,最后拼接为完整文件也是可以的。

(1) HEAD请求 - 只获取响应头

我们通常发送HTTP请求大多数是GET或者POST类型,发送请求后我们会立即获取响应体,浏览器则会根据响应体的类型来处理内容,例如返回的是text/html就会作为网页显示,返回image/png就会解码为图片等等,响应体的类型由响应头Content-Type标识。当我们下载文件时,事实上也是发送HTTP请求,只不过服务器返回的响应体就是文件本身了!其类型则是application/octet-stream,浏览器也知道这是个文件需要下载。

当然,文件作为响应体通常比起网页、图片要大得多,在多线程下载时,我们就要先获取文件的大小,而不是立即获取文件本身,这时我们就可以向服务器发起HEAD请求而不是GET请求。

服务器收到HEAD请求后,就只会返回对应的响应头,而不会返回响应体,这样我们就可以在下载文件之前,读取响应头中的Content-Length来先获取待下载文件大小。

(2) Range请求头 - 只获取部分响应体

知道了文件大小,我们就需要让每个线程只下载一部分文件,借助HTTP的Range请求头,就可以实现只让服务端返回响应体内容的一部分,而不是返回完整的响应体。

这里我们先来借助书籍《图解HTTP》中对Range请求头的讲解,来学习一下:

可见当我们发送一个请求获取内容时,如果指定了如下请求头:

Range: bytes=5001-10000

那么服务端就只会返回响应体的第5001到第10000字节的内容部分,包含第5001和第10000字节,0表示响应体的第一个字节,其HTTP状态码为206

这样,在多个线程同时下载文件时,我们在每个线程的请求中使用Range请求头,就可以实现一个线程只下载文件的一部分了!

(3) Accept-Ranges响应头 - 判断是否支持部分获取

然而,并非所有的服务器都支持通过Range请求头来获取指定部分的内容的。因此,在使用多线程下载之前,我们需要判断一下下载的目标文件能否通过Range部分获取。

我们发送HEAD请求后,通过判断响应头中的Accept-Ranges就可以查看该请求是否支持部分获取,通常有下列情况:

  • 响应头中没有Accept-Ranges字段或者其值为none,说明该请求的响应内容不支持部分获取
  • 响应头中的Accept-Ranges值为bytes,说明该请求的响应内容支持部分获取

事实上,在我们使用浏览器下载文件时,浏览器也会通过检查Accept-Ranges响应头来判断文件能否断点续传的。

(4) 为什么多线程下载可以提升速度?

事实上,在我们客户端(下载文件的)和服务端双向网络通信情况都很好的情况下,使用单线程和多线程下载的速度是几乎没有差异的,也就是说能够跑满我们客户端的全部带宽,那么这种情况下我们使用单线程下载反而更能够节省硬件和网络资源。

但是在我们客户端和服务端之间网络波动较大的情况下,例如我们国内从Github下载文件的时候,就会发现多线程下载速度比单线程快得多,反之使用单线程完全无法充分利用我们的网络带宽。

这种现象事实上是因为TCP连接的慢启动机制导致的,众所周知HTTP是基于TCP的协议,每次我们建立HTTP连接时,包括下载文件,都是在传输层基于TCP协议进行传输。TCP慢启动机制是TCP 协议中一种拥塞控制的机制,目的是在开始数据传输时逐步探测网络的容量,避免瞬间发送大量数据而导致网络拥塞。慢启动不是字面意义上的“慢”,而是相对于立即使用最大带宽而言,它会逐渐增加传输速率。

慢启动机制的过程简要概括如下:

  • 一开始建立连接:当一个新的TCP连接建立后,发送方并不知道当前网络的拥塞情况。因此,发送方不会马上发送大量数据,而是会使用慢启动机制来逐步增加数据传输的速率,在TCP中使用阻塞窗口cwnd来限制发送的数据量,也就是说一开始cwnd是非常小的
  • 拥塞窗口增长:在建立连接后,每当接收到一个确认ACK包时,cwnd指数级增长,直到达到网络的带宽限制或者某个拥塞控制的阈值(称为慢启动阈值ssthresh),这个过程会一直持续,直到发送方探测到网络出现了拥塞(比如丢包或者确认延迟变长),或者cwnd达到了某个预定义的慢启动阈值ssthresh
  • 慢启动的终止:慢启动机制会在以下情况终止:
    • 达到慢启动阈值ssthresh:当拥塞窗口cwnd增长到慢启动阈值ssthresh时,慢启动机制停止,此时TCP会进入另一种拥塞控制机制,称为拥塞避免,这时cwnd增长变为线性而非指数级
    • 发生拥塞(如丢包或超时):如果发送方检测到数据包丢失(例如没有收到确认),它会认为网络已经出现拥塞,此时ssthresh会被调整为当前cwnd的一半,然后cwnd会重置为1 MSS,重新进入慢启动阶段

可见TCP连接使用cwnd限制两者发送的数据量的大小,并逐步“试探”两者传输数据速率的上限并增加传输的数据量。

在我们下载文件时,事实上是服务端在向我们发送文件,如果网络波动较大、不稳定,TCP连接机会一直将cwnd限制在一个较小的值,在单位时间内,服务端也无法向我们发送更大的数据量。

此时,如果我们使用多线程下载,和服务端建立多个TCP连接,这样即使每个TCP连接的cwnd较小,所有TCP连接加起来传输的数据量仍然可以占满我们的带宽。

2,Go代码实现

知道了HTTP的上述几个机制,相信大家就知道如何实现一个简单的多线程下载了!我们可以总结主要步骤如下:

  • 发送HEAD类型请求,通过Content-Length请求头获取待下载文件大小
  • 根据给定的线程数量,结合待下载文件大小,确定每个线程下载的范围部分,也就是每个线程的Range请求头字节范围
  • 创建一个大小与待下载文件大小相同的空白文件
  • 启动所有线程,使得每个线程下载它们对应的部分内容,并将内容写入到空白文件的对应位置,直到全部下载完成

这里分别设计下列类(结构体),用于存放多线程下载时的传入参数和状态量:

上述ShardTask类表示一个线程的下载任务,其中会完成一个分片(文件的一部分)的下载请求操作,它有如下作为参数的属性:

  • Url 下载的文件地址
  • Order 分片序号
  • FilePath 下载的文件路径
  • RangeStartRangeEnd 下载的文件起始范围和结束范围,用于设定Range请求头

此外,还有作为下载状态的属性:

  • DownloadSize 下载任务进行时,这个线程已下载的文件部分大小
  • TaskDone 这个线程的下载任务是否完成

该类的成员方法如下:

  • DoShardGet 执行分片下载任务,在其中会根据RangeStartRangeEnd设定对应的HTTP请求头,并计算写入空白文件的位置,发送请求并下载对应的文件部分

然后就是ParallelGetTask类,表示一整个多线程下载任务,其中包含了一个多线程下载任务的参数和状态量,并且实现了多线程下载的每个步骤,它有如下作为参数的属性:

  • Url 文件的下载链接
  • FilePath 下载的文件路径
  • Concurrent 下载并发数,即同时下载的分片数量

此外还有作为状态的属性:

  • TotalSize 待下载文件的总大小
  • ShardTaskList 存储所有分片任务对象指针的列表

该类中的方法主要是分片下载的一些步骤如下:

  • getLength 发送HEAD请求获取Content-Length以获取文件大小,获取后将其设定到TotalSize属性
  • allocateTask 根据给定的线程数和获取到的文件大小,计算每个线程下载的文件内容范围,并创建对应的ShardTask结构体放入ShardTaskList
  • createFile 根据getLength获取到的待下载文件大小,创建一个相同大小的空白文件,用于接收下载内容
  • downloadShard 为每一个ShardTask对象创建一个线程(Goroutine)并在新的线程中调用ShardTask对象的下载分片方法,以启动所有线程的下载任务,并通过sync.WaitGroup来等待全部线程完成
  • Run 启动整个多线程下载任务,该函数是暴露的公开函数,其中对上述每个步骤函数进行了组织,按顺序调用执行

下面,我们来看一下它们的代码实现。

(1) ShardTask - 一个线程的下载任务

代码实现如下:

package main

import (
	"bufio"
	"fmt"
	"github.com/fatih/color"
	"io"
	"net/http"
	"os"
	"sync"
)

// 全局HTTP客户端
var httpClient = http.Client{
	// 关闭超时,否则会导致下载时间超过超时时间时,TCP连接被自动断开
	Timeout: 0,
	Transport: &http.Transport{
		// 从环境变量读取代理配置
		Proxy: http.ProxyFromEnvironment,
		// 关闭keep-alive确保一个线程就使用一个TCP连接
		DisableKeepAlives: true,
	},
}

// ShardTask 单个分片下载任务的任务参数和状态量
type ShardTask struct {
	// 下载链接
	Url string
	// 分片序号,从1开始
	Order int
	// 下载的文件路径
	FilePath string
	// 分片的起始范围(字节,包含)
	RangeStart int64
	// 分片的结束范围(字节,包含)
	RangeEnd int64
	// 已下载的部分(字节)
	DownloadSize int64
	// 该任务是否完成
	TaskDone bool
}

// NewShardTask 构造函数
func NewShardTask(url string, order int, filePath string, rangeStart int64, rangeEnd int64) *ShardTask {
	return &ShardTask{
		// 设定任务参数
		Url:        url,
		Order:      order,
		FilePath:   filePath,
		RangeStart: rangeStart,
		RangeEnd:   rangeEnd,
		// 初始化状态量
		DownloadSize: 0,
		TaskDone:     false,
	}
}

// DoShardGet 开始下载这个分片(该方法在goroutine中执行)
//
// 将会根据RangeStart和RangeEnd范围,在写入目的文件时,将文件指针Seek到RangeStart位置处开始写入文件
// 实现多个线程同时向一个文件不同位置分别并发写入内容
func (task *ShardTask) DoShardGet(waitGroup *sync.WaitGroup) {
	// 打开文件
	file, e := os.OpenFile(task.FilePath, os.O_WRONLY, 0755)
	if e != nil {
		color.Red("任务%d打开文件失败!", task.Order)
		color.HiRed("%s", e)
		waitGroup.Done()
		return
	}
	// 结束时关闭文件
	defer func() {
		_ = file.Close()
	}()
	// 设定写入偏移量
	_, e = file.Seek(task.RangeStart, io.SeekStart)
	if e != nil {
		color.Red("任务%d设定文件偏移量失败!", task.Order)
		color.HiRed("%s", e)
		waitGroup.Done()
		return
	}
	// 准备请求
	request, e := http.NewRequest("GET", task.Url, nil)
	if e != nil {
		color.Red("任务%d创建请求出错!", task.Order)
		color.HiRed("%s", e)
		waitGroup.Done()
		return
	}
	// 设定请求头
	request.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", task.RangeStart, task.RangeEnd))
	// 发送请求
	response, e := httpClient.Do(request)
	if e != nil {
		color.Red("任务%d发送下载请求出错!", task.Order)
		color.HiRed("%s", e)
		waitGroup.Done()
		return
	}
	// 结束时关闭请求体读取器
	defer func() {
		_ = response.Body.Close()
	}()
	// 读取缓冲区(64KB)
	buffer := make([]byte, 64*1024)
	// 准备写入文件
	writer := bufio.NewWriter(file)
	for {
		// 读取请求体
		// 读取一次内容至缓冲区
		readSize, readError := response.Body.Read(buffer)
		// 处理非EOF错误
		if readError != nil && readError != io.EOF {
			// 如果读取完毕则退出循环
			color.Red("任务%d读取响应错误!", task.Order)
			color.HiRed("%s", readError)
			waitGroup.Done()
			return
		}
		// 读取到内容时,写入文件
		if readSize > 0 {
			// 把缓冲区内容写入文件
			_, writeError := writer.Write(buffer[:readSize])
			if writeError != nil {
				color.Red("任务%d写入文件写入器时出现错误!", task.Order)
				color.HiRed("%s", writeError)
				waitGroup.Done()
				return
			}
			writeError = writer.Flush()
			if writeError != nil {
				color.Red("任务%d写入文件时出现错误!", task.Order)
				color.HiRed("%s", writeError)
				waitGroup.Done()
				return
			}
		}
		// 记录下载进度
		task.DownloadSize += int64(readSize)
		// 若读取到末尾则退出
		if readError == io.EOF {
			break
		}
	}
	// 标记任务完成
	task.TaskDone = true
	// 使线程组中计数器-1
	waitGroup.Done()
}

构造函数NewShardTask负责完成ShardTask的参数传入和状态量初始化,而DoShardGet方法实现了下载一个文件分片的完整步骤,从打开文件准备写入,设定写入文件的偏移量,到设定请求头,发出请求,最后读取响应体写入到文件。

可通过文件对象的Seek方法实现移动文件指针到指定位置,上述Seek(task.RangeStart, io.SeekStart)就表示将文件指针向后偏移task.RangeStart个字节,以文件开头(第1个字节,下标为0处)作为参照。移动文件指针后,后续写入操作将会从该文件指针处开始覆盖写入,每写入一个字节,文件指针也会向后移动一位。

在这里对于http.Client对象的配置,有以下要点:

  • DisableKeepAlives设为了true,即关闭keep-alive,这是因为默认情况下Go语言的HTTP客户端会复用TCP连接,即使你多个线程发起请求,也会使用一个TCP连接进行,而多线程下载需要每个线程持有一个单独的TCP连接来达到突破cwnd的限制,因此这里关闭keep-alive实现每个线程发起请求时,使用单独的TCP连接
  • Timeout设为0表示关闭超时,这是因为如果有分片读取请求体(也就是下载文件)的时间超过了超时时间,就会被强行断开导致下载失败
  • 此外,还使用http.ProxyFromEnvironment配置为代理,表示可以自动从环境变量读取代理配置

(2) ParallelGetTask - 一整个多线程下载任务

代码实现如下:

package main

import (
	"errors"
	"fmt"
	"github.com/fatih/color"
	"net/http"
	"os"
	"sync"
)

// ParallelGetTask 多线程下载任务类,存放一个多线程下载任务的参数和状态量
type ParallelGetTask struct {
	// 文件的下载链接
	Url string
	// 文件的最终保存位置
	FilePath string
	// 下载并发数
	Concurrent int
	// 下载文件的总大小
	TotalSize int64
	// 全部的下载分片任务参数列表
	ShardTaskList []*ShardTask
}

// NewParallelGetTask 构造函数
func NewParallelGetTask(url, filePath string, concurrent int) *ParallelGetTask {
	return &ParallelGetTask{
		// 参数赋值
		Url:        url,
		FilePath:   filePath,
		Concurrent: concurrent,
		// 初始化状态量
		TotalSize:     0,
		ShardTaskList: make([]*ShardTask, 0),
	}
}

// 发送HEAD请求获取待下载文件的大小
func (task *ParallelGetTask) getLength() error {
	// 发送HEAD请求
	response, e := http.Head(task.Url)
	if e != nil {
		color.Red("发送HEAD请求出错!")
		return e
	}
	// 判断状态码是否正确,不正确说明Head不被允许,切换为Get重试
	if response.StatusCode >= 300 {
		color.HiYellow("不支持HEAD请求,状态码:%d,使用GET重试...", response.StatusCode)
		response, e = http.Get(task.Url)
		if e != nil {
			color.Red("发送GET请求获取大小出错!")
			return e
		}
		// 最终直接关闭响应体,不进行读取
		defer func() {
			_ = response.Body.Close()
		}()
		// 再次检查状态码,若不正确则返回错误
		if response.StatusCode >= 300 {
			color.Red("发送GET请求获取大小出错!状态码:%d", response.StatusCode)
			return errors.New(fmt.Sprintf("状态码不正确:%d", response.StatusCode))
		}
	}
	// 检查是否支持部分请求
	if response.Header.Get("Accept-Ranges") != "bytes" {
		color.Red("该请求不支持部分获取,无法分片下载!")
		return errors.New("不支持分片获取")
	}
	// 读取并设定长度
	task.TotalSize = response.ContentLength
	if task.TotalSize <= 0 {
		color.Red("无法获取内容长度!")
		return errors.New("不能获取内容长度")
	}
	return nil
}

// 根据待下载文件的大小和设定的并发数,创建每个分片任务对象
func (task *ParallelGetTask) allocateTask() {
	// 如果并发数大于总大小,则进行调整
	if int64(task.Concurrent) > task.TotalSize {
		task.Concurrent = int(task.TotalSize)
	}
	// 开始计算每个分片的下载范围
	eachSize := task.TotalSize / int64(task.Concurrent)
	// 创建任务对象
	for i := 0; i < task.Concurrent; i++ {
		task.ShardTaskList = append(task.ShardTaskList, NewShardTask(task.Url, i+1, task.FilePath, int64(i)*eachSize, int64(i+1)*eachSize-1))
	}
	// 处理末尾部分
	if task.TotalSize%int64(task.Concurrent) != 0 {
		task.ShardTaskList[task.Concurrent-1].RangeEnd = task.TotalSize - 1
	}
}

// 创建一个与目标下载文件大小一样的空白的文件
func (task *ParallelGetTask) createFile() error {
	file, e := os.OpenFile(task.FilePath, os.O_WRONLY|os.O_CREATE, 0755)
	if e != nil {
		color.Red("创建文件出错!")
		return e
	}
	defer func() {
		_ = file.Close()
	}()
	// Truncate能够调整文件到指定大小
	// 若文件大小小于给定大小,则会使用空白字符填充并扩充至给定大小
	// 否则截断文件至给定大小,丢弃多余部分
	e = file.Truncate(task.TotalSize)
	if e != nil {
		color.Red("调整文件大小出错!")
		return e
	}
	return nil
}

// 根据任务列表进行多线程分片下载操作
func (task *ParallelGetTask) downloadShard() {
	// 创建线程组
	waitGroup := &sync.WaitGroup{}
	// 开始执行全部分片下载线程
	for _, task := range task.ShardTaskList {
		go task.DoShardGet(waitGroup)
		waitGroup.Add(1)
	}
	// 等待全部下载完成
	waitGroup.Wait()
}

// Run 开始执行整个分片多线程下载任务
func (task *ParallelGetTask) Run() error {
	// 获取文件大小
	e := task.getLength()
	if e != nil {
		color.Red("%s", e)
		return e
	}
	color.HiYellow("已获取到下载文件大小:%d字节", task.TotalSize)
	// 分配任务
	task.allocateTask()
	color.HiYellow("已完成分片任务分配,共计%d个任务", len(task.ShardTaskList))
	// 创建空白文件
	e = task.createFile()
	if e != nil {
		color.Red("%s", e)
		return e
	}
	color.HiYellow("已预创建目标下载文件!")
	// 开启进度输出
	printProcess(task)
	// 开始下载文件
	task.downloadShard()
	color.Green("\n分片下载任务完成!")
	return nil
}

可见通过构造函数NewParallelGetTask完成参数传递和状态量设定后,其它每个私有函数都对应我们多线程下载中的一个步骤,最后由公开函数Run统筹组织起所有的步骤,完成整个多线程下载任务。

在创建空白文件createFile方法中,我们使用文件对象的Truncate方法实现了创建一个指定大小的空白文件,该方法实际上是用于调整文件大小为指定值的,如果:

  • 文件大小小于给定大小,则会使用空白字符填充并扩充至给定大小
  • 文件大小大于给定大小,截断文件至给定大小,丢弃多余部分

(3) 实用函数

ParallelGetTask的成员方法Run中,有一个实用函数printProcess是用于在一个单独的线程实现进度的实时输出的,该实用函数及其它函数实现如下:

package main

import (
	"fmt"
	"math"
	"time"
)

// 关于下载进度计算和显示的实用工具函数

// 计算网络速度
// size 一段时间内下载的数据大小,单位字节
// timeElapsed 经过的时间长度,单位毫秒
// 返回计算得到的网速,会自动换算单位
func computeSpeed(size int64, timeElapsed int) string {
	bytePerSecond := size / int64(timeElapsed) * 1000
	if 0 <= bytePerSecond && bytePerSecond <= 1024 {
		return fmt.Sprintf("%d Byte/s", bytePerSecond)
	}
	if bytePerSecond > 1024 && bytePerSecond <= int64(math.Pow(1024, 2)) {
		return fmt.Sprintf("%.2f KB/s", float64(bytePerSecond)/1024)
	}
	if bytePerSecond > 1024*1024 && bytePerSecond <= int64(math.Pow(1024, 3)) {
		return fmt.Sprintf("%.2f MB/s", float64(bytePerSecond)/math.Pow(1024, 2))
	}
	return fmt.Sprintf("%.2f GB/s", float64(bytePerSecond)/math.Pow(1024, 3))
}

// 在一个新的线程中,实时在控制台输出任务下载进度
//
// task 读取并显示进度的任务对象
func printProcess(task *ParallelGetTask) {
	go func() {
		// 上一次统计时的已下载大小,用于计算速度
		var lastDownloadSize int64 = 0
		for {
			// 如果全部任务完成则结束输出,并统计并发数
			allDone := true
			// 当前并发数
			currentTaskCount := 0
			for _, shardTask := range task.ShardTaskList {
				if !shardTask.TaskDone {
					allDone = false
					currentTaskCount += 1
				}
			}
			if allDone {
				break
			}
			// 统计所有分片已下载大小之和
			var totalDownloadSize int64 = 0
			for _, shardTask := range task.ShardTaskList {
				totalDownloadSize += shardTask.DownloadSize
			}
			// 计算速度
			currentDownload := totalDownloadSize - lastDownloadSize
			lastDownloadSize = totalDownloadSize
			speedString := computeSpeed(currentDownload, 300)
			// 输出到控制台
			// \r 回到行首
			// \033[2K 为ANSI转义控制字符,用于清空当前行
			fmt.Printf("\r\033[2K当前并发数:%d 速度:%s 总进度:%.2f%%", currentTaskCount, speedString, float32(totalDownloadSize)/float32(task.TotalSize)*100)
			// 等待300ms
			time.Sleep(300 * time.Millisecond)
		}
	}()
}

3,实现效果

现在我们在main函数中创建一个ParallelGetTask对象,设定好参数后调用其Run方法即可开始多线程下载文件的任务:

package main

import (
	"gitee.com/swsk33/shard-download-demo/model"
)

func main() {
	// 创建任务
	task := model.NewParallelGetTask(
		"https://github.com/jgraph/drawio-desktop/releases/download/v24.7.17/draw.io-24.7.17-windows-installer.exe",
		"downloads/draw.io.exe",
		64,
	)
	// 执行任务
	_ = task.Run()
}

效果如下:

这里是直接下载的Github文件为例,可以来对比一下使用64线程和1个线程下载的速度差异:

可见借助HTTP请求的一些机制,我们就可以实现一个多线程下载功能了!当然这里的程序还有许多可以完善的地方,例如失败重试、断点续传等等。

代码仓库地址:传送门