1. 引子
IDM是多线程下载文件的典范,使用多线程可以抢占服务器资源,提高下载速度。
在这篇教程中,我将使用kotlin
,ktor-client
构建一个支持多线程下载且支持断点续传的文件下载器
2. 细节
2.1 多线程下载的原理?
多线程下载,就是要让整个下载任务分成若干份,让每个线程分别下载属于自己的一部分。
在这期间,程序需要随时记住自己的下载状态,方便从暂停状态中继续。
当所有部分下载完成后,下载宣告完成。
2.2 如何让线程获取特定范围内的资源?
让每个线程下载属于自己的一部分,则需要请求到这一部分的资源。在HTTP协议中,Range
头提供了这一功能。它要求提供一个闭区间数据以供返回url的特定范围的字节(下文称为切片)
需要注意的是,不是所有的服务器都是支持使用Range头获取切片的。我们可以使用
Accept-Ranges
头来询问服务器这个url是否支持Range
头。
2.3 如何将缓冲区写入文件的指定位置?
在下载到所需字节后,需要将这部分缓冲区写入到文件中。RandomAccessFile
可以做到了这点。
3. 实现
3.1 创建项目
3.2 引入依赖
让我们引入serialization
和ktor-client
plugins {
kotlin("jvm") version "1.9.23"
kotlin("plugin.serialization") version "1.9.23"
}
group = "top.kagg886.study"
version = "1.0-SNAPSHOT"
repositories {
mavenCentral()
maven("https://s01.oss.sonatype.org/content/repositories/snapshots")
}
dependencies {
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.6.0")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.8.1")
implementation("io.ktor:ktor-client-core:2.3.12")
implementation("io.ktor:ktor-client-cio:2.3.12")
testImplementation(kotlin("test"))
}
tasks.test {
useJUnitPlatform()
}
kotlin {
jvmToolchain(11)
}
3.3 创建文件架构
一个多线程下载任务至少需要记住url
,目标位置。
为了方便计算进度,我们可以在第一次加载时缓存文件大小。
至于blockCount
和bufferSize
,下文再议。
@Serializable
data class DownloadService(
val url: String,
val contentLength: Long,
val targetPath: String,
private val blockCount: Int,
private val bufferSize: Long
) {
}
我们需要存储每个线程的文件下载状态,因此让我们为DownloadService
添加一个成员变量:
private val unCompleteTaskList = mutableListOf<DownloadTask>()
对应的,
@Serializable
data class DownloadTask(
var start: Long,
var end: Long,
var downloadSize: Long = 0
)
由于RandomAccessFile的seek
方法不是线程安全的,因此我们需要一把协程锁:
private val muteX by lazy {
Mutex()
}
再添加几个成员变量,总的代码大概是这样的:
@Serializable
data class DownloadService(
val url: String,
val contentLength: Long,
val targetPath: String,
private val blockCount: Int,
private val bufferSize: Long
) {
private val file: RandomAccessFile by lazy {
RandomAccessFile(targetPath, "rwd")
}
private val muteX by lazy {
Mutex()
}
//每个线程负责的切片大小
private val blockSize by lazy {
ceil(contentLength.toFloat() / blockCount).roundToLong()
}
private val unCompleteTaskList = mutableListOf<DownloadTask>()
//已经下载的大小
val downloadSize: Long
get() = unCompleteTaskList.sumOf { it.downloadSize }
//已经下载的进度
val progress
get() = downloadSize.toFloat() / contentLength
}
3.4 初始化下载任务
DownloadService
需要依赖于ktor-client
来创建HttpClient
,因此让我们继续增加成员变量:
//是否实例化了HTTPClient
private val init
get() = client != null
//HTTP客户端
@Transient
private var client: HttpClient? = null
让我们增加一个init
函数来绑定HTTPClient
和DownloadTask
之间的关系:
趁着第一次调用init时,将任务切片做好。
fun init(client: HttpClient) {
check(!init) {
"service already init"
}
this.client = client
if (unCompleteTaskList.isEmpty()) {
unCompleteTaskList.addAll(
//Range是闭区间,step值需要比blockSize大1以避免上一个块的末尾和下一个块的头重复
(0..<contentLength step blockSize + 1).map {
val start = it
val end = min(it + blockSize, contentLength)
DownloadTask(
start = start,
end = end,
downloadSize = 0
)
}
)
}
}
任务切片的获取可以通过一个DSL获得:
class DownloadServiceConfig {
//下载地址
lateinit var url: String
//目标文件
lateinit var targetPath: File
//分块数
var blockCount: Int = 16
var bufferedSize: Long = 1024
}
在这里我们定义切片数和缓冲区大小,方便后期自定义配置。
companion object {
/**
* # 创建下载服务
* @param block 配置
*/
suspend fun create(
block: DownloadServiceConfig.() -> Unit
): DownloadService {
val config = DownloadServiceConfig().apply(block)
//创建一个简单的HTTPClient
val client = HttpClient(CIO) {
install(HttpTimeout)
}
val url = config.url
//使用Head请求查询是否支持Range头
val response = client.head(url) {
headers.append("Accept-Ranges", "acceptable-ranges")
}.call.response.headers
//见MDN
check(response["Accept-Ranges"] == "bytes") {
"$url not support multi-thread download"
}
//获取请求头长度
val contentLength = response["Content-Length"]?.toLongOrNull()
check(contentLength != null && contentLength > 0) {
"Content Length is Null"
}
return DownloadService(
url = url,
contentLength = contentLength,
targetPath = config.targetPath.absolutePath,
blockCount = config.blockCount,
bufferSize = config.bufferedSize
)
}
}
3.5 开始、暂停下载和等待下载完成
协程的launch会返回一个job,因此需要在DownloadService
中添加一个job,当文件下载时,该job代表正在执行的任务。
@Transient
private var job: Job? = null
定义一个开始下载文件的函数,暂停函数和等待函数
fun start(scope: CoroutineScope, block: HttpRequestBuilder.() -> Unit = {}) {
//检查是否初始化
check(init) {
"service not init"
}
//检查是否有别的下载任务
check(job == null) {
"download still in progress"
}
//启动一个协程
job = scope.launch {
//筛选出未下载完成的切片
unCompleteTaskList.filter { it.downloadSize <= blockSize }.map { taskConfig ->
//启动子协程,在子协程中查询文件
launch {
//使用prepareGet而不是get,避免ktor帮我们将body下载好。
val statement = client!!.prepareGet(url) {
block()
headers {
//Range头,通过这样的格式获取切片数据
append("Range", "bytes=${taskConfig.start + taskConfig.downloadSize}-${taskConfig.end}")
}
}
//开始执行请求
statement.execute {
//获取channel
val channel = it.bodyAsChannel()
while (true) {
//读取最多bufferSize的字节,若无读取内容则返回空数组。此时证明切片下载完毕。
val bytes = channel.readRemaining(bufferSize).readBytes()
if (bytes.isEmpty()) {
break
}
//RandomAccessFile的seek函数非线程安全。
muteX.withLock {
//IO操作需要转移到别的线程中运行。
withContext(Dispatchers.IO) {
file.seek(taskConfig.start + taskConfig.downloadSize)
file.write(bytes, 0, bytes.size)
}
taskConfig.downloadSize += bytes.size
}
}
}
}
}.joinAll() //等待全部任务下载完成,然后将job置为null代表下载任务结束。
job = null
}
}
suspend fun await() {
job?.join() ?: throw IllegalStateException("任务未开始")
}
fun pause() {
check(init) {
"service not init"
}
check(job != null) {
"download not started"
}
job!!.cancel()
job = null
}
3.6 完整代码
import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.utils.io.core.*
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient
import java.io.File
import java.io.RandomAccessFile
import kotlin.math.ceil
import kotlin.math.min
import kotlin.math.roundToLong
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.seconds
@Serializable
data class DownloadService(
val url: String,
val contentLength: Long,
val targetPath: String,
private val blockCount: Int,
private val bufferSize: Long
) {
//rwd文件
private val file: RandomAccessFile by lazy {
RandomAccessFile(targetPath, "rwd")
}
//协程锁
private val muteX by lazy {
Mutex()
}
//分配的下载块大小
private val blockSize by lazy {
ceil(contentLength.toFloat() / blockCount).roundToLong()
}
//下载任务
private val unCompleteTaskList = mutableListOf<DownloadTask>()
//是否调用了init函数
private val init
get() = client != null
//HTTP客户端
@Transient
private var client: HttpClient? = null
//下载任务
@Transient
private var job: Job? = null
//已经下载的大小
val downloadSize: Long
get() = unCompleteTaskList.sumOf { it.downloadSize }
//已经下载的进度
val progress
get() = downloadSize.toFloat() / contentLength
//是否正在下载
val isDownloading
get() = job != null
fun init(client: HttpClient) {
check(!init) {
"service already init"
}
this.client = client
if (unCompleteTaskList.isEmpty()) {
unCompleteTaskList.addAll(
//Range是闭区间,step值需要比blockSize大1以避免上一个块的末尾和下一个块的头重复
(0..<contentLength step blockSize + 1).map {
val start = it
val end = min(it + blockSize, contentLength)
DownloadTask(
start = start,
end = end,
downloadSize = 0
)
}
)
}
}
fun start(scope: CoroutineScope, block: HttpRequestBuilder.() -> Unit = {}) {
//检查是否初始化
check(init) {
"service not init"
}
//检查是否有别的下载任务
check(job == null) {
"download still in progress"
}
//启动一个协程
job = scope.launch {
//筛选出未下载完成的切片
unCompleteTaskList.filter { it.downloadSize <= blockSize }.map { taskConfig ->
//启动子协程,在子协程中查询文件
launch {
//使用prepareGet而不是get,避免ktor帮我们将body下载好。
val statement = client!!.prepareGet(url) {
block()
headers {
//Range头,通过这样的格式获取切片数据
append("Range", "bytes=${taskConfig.start + taskConfig.downloadSize}-${taskConfig.end}")
}
}
//开始执行请求
statement.execute {
//获取channel
val channel = it.bodyAsChannel()
while (true) {
//读取最多bufferSize的字节,若无读取内容则返回空数组。此时证明切片下载完毕。
val bytes = channel.readRemaining(bufferSize).readBytes()
if (bytes.isEmpty()) {
break
}
//RandomAccessFile的seek函数非线程安全。
muteX.withLock {
//IO操作需要转移到别的线程中运行。
withContext(Dispatchers.IO) {
file.seek(taskConfig.start + taskConfig.downloadSize)
file.write(bytes, 0, bytes.size)
}
taskConfig.downloadSize += bytes.size
}
}
}
}
}.joinAll() //等待全部任务下载完成,然后将job置为null代表下载任务结束。
job = null
}
}
suspend fun await() {
job?.join() ?: throw IllegalStateException("任务未开始")
}
fun pause() {
check(init) {
"service not init"
}
check(job != null) {
"download not started"
}
job!!.cancel()
job = null
}
companion object {
suspend fun create(
block: DownloadServiceConfig.() -> Unit
): DownloadService {
val config = DownloadServiceConfig().apply(block)
val client = HttpClient(CIO) {
install(HttpTimeout)
}
val url = config.url
val response = client.head(url) {
headers.append("Accept-Ranges", "acceptable-ranges")
}.call.response.headers
check(response["Accept-Ranges"] == "bytes") {
"$url not support multi-thread download"
}
val contentLength = response["Content-Length"]?.toLongOrNull()
check(contentLength != null && contentLength > 0) {
"Content Length is Null"
}
return DownloadService(
url = url,
contentLength = contentLength,
targetPath = config.targetPath.absolutePath,
blockCount = config.blockCount,
bufferSize = config.bufferedSize
)
}
}
}
class DownloadServiceConfig {
//下载地址
lateinit var url: String
//目标文件
lateinit var targetPath: File
//分块数
var blockCount: Int = 16
var bufferedSize: Long = 1024
}
@Serializable
data class DownloadTask(
var start: Long,
var end: Long,
var downloadSize: Long = 0
)
3.7 实战
让我们挑一个受害者文件:
https://download.oracle.com/java/19/archive/jdk-19.0.2_linux-x64_bin.tar.gz
它的SHA-256是59f26ace2727d0e9b24fc09d5a48393c9dbaffe04c932a02938e8d6d582058c6
让我们写一个主函数下载文件:
fun main(): Unit = runBlocking {
//实例化CIO客户端
val client = HttpClient(CIO) {
install(HttpTimeout) {
requestTimeoutMillis = 30.minutes.inWholeMicroseconds
connectTimeoutMillis = 30.minutes.inWholeMicroseconds
socketTimeoutMillis = 30.minutes.inWholeMicroseconds
}
}
//创建下载任务
val service = DownloadService.create {
url = "https://download.oracle.com/java/19/archive/jdk-19.0.2_linux-x64_bin.tar.gz"
targetPath = File("jdk.tar.gz").apply {
if (exists()) {
delete()
}
//仅在第一次时初始化文件,若文件被删除可抛出异常
absoluteFile.parentFile.mkdirs()
createNewFile()
}
}
//绑定下载任务到客户端
service.init(client)
//开始下载
service.start(this)
launch {
//是否下载完成
while (service.isDownloading) {
delay(1.seconds)
with(service) {
println("download: $downloadSize --- $contentLength (${progress})")
}
}
}
runCatching {
service.await()
}.onFailure {
println("任务被暂停或遇到异常")
}
}
在这期间想要保存进度了可以使用下面的代码将当前的进度保存为json字符串:
File("progress.json").writeString(Json.encodeToString(service))
读取保存进度则可使用:
val service1 = Json.decodeFromString<DownloadService>(File("progress.json").readText())
service1.init(client)
service.start(...)
4. 总结
在这篇教程中,我们学习了如何使用ktor
获取文件切片,如何用RandomAccessFile
将数据下载到指定的位置。
而且基于kotlin对协程的强大并发特性,使得我们不需要像Java那样传一些没必要的参到子类中,让代码变的直观明白。
最后附一份完整代码的github链接,有爱自取哦: