[Kotlin] 利用协程进行多线程文件下载

736 阅读5分钟

1. 引子

IDM是多线程下载文件的典范,使用多线程可以抢占服务器资源,提高下载速度。

在这篇教程中,我将使用kotlin,ktor-client构建一个支持多线程下载且支持断点续传的文件下载器

2. 细节

2.1 多线程下载的原理?

多线程下载,就是要让整个下载任务分成若干份,让每个线程分别下载属于自己的一部分。

在这期间,程序需要随时记住自己的下载状态,方便从暂停状态中继续。

当所有部分下载完成后,下载宣告完成。

2.2 如何让线程获取特定范围内的资源?

让每个线程下载属于自己的一部分,则需要请求到这一部分的资源。在HTTP协议中,Range头提供了这一功能。它要求提供一个闭区间数据以供返回url的特定范围的字节(下文称为切片)

需要注意的是,不是所有的服务器都是支持使用Range头获取切片的。我们可以使用Accept-Ranges头来询问服务器这个url是否支持Range头。

2.3 如何将缓冲区写入文件的指定位置?

在下载到所需字节后,需要将这部分缓冲区写入到文件中。RandomAccessFile可以做到了这点。

3. 实现

3.1 创建项目

3.2 引入依赖

让我们引入serializationktor-client

plugins {
    kotlin("jvm") version "1.9.23"
    kotlin("plugin.serialization") version "1.9.23"
}
​
group = "top.kagg886.study"
version = "1.0-SNAPSHOT"
​
repositories {
    mavenCentral()
    maven("https://s01.oss.sonatype.org/content/repositories/snapshots")
}
​
dependencies {
    implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.6.0")
    implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.8.1")
    implementation("io.ktor:ktor-client-core:2.3.12")
    implementation("io.ktor:ktor-client-cio:2.3.12")
    testImplementation(kotlin("test"))
}
​
tasks.test {
    useJUnitPlatform()
}
kotlin {
    jvmToolchain(11)
}

3.3 创建文件架构

一个多线程下载任务至少需要记住url,目标位置。

为了方便计算进度,我们可以在第一次加载时缓存文件大小。

至于blockCountbufferSize,下文再议。

@Serializable
data class DownloadService(
    val url: String,
    val contentLength: Long,
    val targetPath: String,
    private val blockCount: Int,
    private val bufferSize: Long
) {
    
}

我们需要存储每个线程的文件下载状态,因此让我们为DownloadService添加一个成员变量:

private val unCompleteTaskList = mutableListOf<DownloadTask>()

对应的,

@Serializable
data class DownloadTask(
    var start: Long,
    var end: Long,
    var downloadSize: Long = 0
)

由于RandomAccessFile的seek方法不是线程安全的,因此我们需要一把协程锁:

private val muteX by lazy {
    Mutex()
}

再添加几个成员变量,总的代码大概是这样的:

@Serializable
data class DownloadService(
    val url: String,
    val contentLength: Long,
    val targetPath: String,
    private val blockCount: Int,
    private val bufferSize: Long
) {
    private val file: RandomAccessFile by lazy {
        RandomAccessFile(targetPath, "rwd")
    }
​
    private val muteX by lazy {
        Mutex()
    }
​
    //每个线程负责的切片大小
    private val blockSize by lazy {
        ceil(contentLength.toFloat() / blockCount).roundToLong()
    }
​
    private val unCompleteTaskList = mutableListOf<DownloadTask>()
​
    //已经下载的大小
    val downloadSize: Long
        get() = unCompleteTaskList.sumOf { it.downloadSize }
​
    //已经下载的进度
    val progress
        get() = downloadSize.toFloat() / contentLength
}
​

3.4 初始化下载任务

DownloadService需要依赖于ktor-client来创建HttpClient,因此让我们继续增加成员变量:

//是否实例化了HTTPClient
private val init
    get() = client != null//HTTP客户端
@Transient
private var client: HttpClient? = null

让我们增加一个init函数来绑定HTTPClientDownloadTask之间的关系:

趁着第一次调用init时,将任务切片做好。

fun init(client: HttpClient) {
    check(!init) {
        "service already init"
    }
    this.client = client
    if (unCompleteTaskList.isEmpty()) {
        unCompleteTaskList.addAll(
            //Range是闭区间,step值需要比blockSize大1以避免上一个块的末尾和下一个块的头重复
            (0..<contentLength step blockSize + 1).map {
                val start = it
                val end = min(it + blockSize, contentLength)
                DownloadTask(
                    start = start,
                    end = end,
                    downloadSize = 0
                )
            }
        )
    }
}

任务切片的获取可以通过一个DSL获得:

class DownloadServiceConfig {
    //下载地址
    lateinit var url: String
    //目标文件
    lateinit var targetPath: File
    //分块数
    var blockCount: Int = 16
    var bufferedSize: Long = 1024
}

在这里我们定义切片数和缓冲区大小,方便后期自定义配置。

companion object {
    /**
     * # 创建下载服务
     * @param block 配置
     */
    suspend fun create(
        block: DownloadServiceConfig.() -> Unit
    ): DownloadService {
        val config = DownloadServiceConfig().apply(block)
        //创建一个简单的HTTPClient
        val client = HttpClient(CIO) {
            install(HttpTimeout)
        }
        val url = config.url
        //使用Head请求查询是否支持Range头
        val response = client.head(url) {
            headers.append("Accept-Ranges", "acceptable-ranges")
        }.call.response.headers
        //见MDN
        check(response["Accept-Ranges"] == "bytes") {
            "$url not support multi-thread download"
        }
        //获取请求头长度
        val contentLength = response["Content-Length"]?.toLongOrNull()
        check(contentLength != null && contentLength > 0) {
            "Content Length is Null"
        }
        return DownloadService(
            url = url,
            contentLength = contentLength,
            targetPath = config.targetPath.absolutePath,
            blockCount = config.blockCount,
            bufferSize = config.bufferedSize
        )
    }
}

3.5 开始、暂停下载和等待下载完成

协程的launch会返回一个job,因此需要在DownloadService中添加一个job,当文件下载时,该job代表正在执行的任务。

@Transient
private var job: Job? = null

定义一个开始下载文件的函数,暂停函数和等待函数

fun start(scope: CoroutineScope, block: HttpRequestBuilder.() -> Unit = {}) {
    //检查是否初始化
    check(init) {
        "service not init"
    }
    //检查是否有别的下载任务
    check(job == null) {
        "download still in progress"
    }
    //启动一个协程
    job = scope.launch {
        //筛选出未下载完成的切片
        unCompleteTaskList.filter { it.downloadSize <= blockSize }.map { taskConfig ->
            //启动子协程,在子协程中查询文件
            launch {
                //使用prepareGet而不是get,避免ktor帮我们将body下载好。
                val statement = client!!.prepareGet(url) {
                    block()
                    headers {
                        //Range头,通过这样的格式获取切片数据
                        append("Range", "bytes=${taskConfig.start + taskConfig.downloadSize}-${taskConfig.end}")
                    }
                }
                //开始执行请求
                statement.execute {
                    //获取channel
                    val channel = it.bodyAsChannel()
​
                    while (true) {
                        //读取最多bufferSize的字节,若无读取内容则返回空数组。此时证明切片下载完毕。
                        val bytes = channel.readRemaining(bufferSize).readBytes()
                        if (bytes.isEmpty()) {
                            break
                        }
                        //RandomAccessFile的seek函数非线程安全。
                        muteX.withLock {
                            //IO操作需要转移到别的线程中运行。
                            withContext(Dispatchers.IO) {
                                file.seek(taskConfig.start + taskConfig.downloadSize)
                                file.write(bytes, 0, bytes.size)
                            }
                            taskConfig.downloadSize += bytes.size
                        }
                    }
                }
            }
        }.joinAll() //等待全部任务下载完成,然后将job置为null代表下载任务结束。
        job = null
    }
}
​
suspend fun await() {
    job?.join() ?: throw IllegalStateException("任务未开始")
}
​
fun pause() {
    check(init) {
        "service not init"
    }
    check(job != null) {
        "download not started"
    }
    job!!.cancel()
    job = null
}
​

3.6 完整代码

import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.utils.io.core.*
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient
import java.io.File
import java.io.RandomAccessFile
import kotlin.math.ceil
import kotlin.math.min
import kotlin.math.roundToLong
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.seconds
​
@Serializable
data class DownloadService(
    val url: String,
    val contentLength: Long,
    val targetPath: String,
    private val blockCount: Int,
    private val bufferSize: Long
) {
    //rwd文件
    private val file: RandomAccessFile by lazy {
        RandomAccessFile(targetPath, "rwd")
    }
​
    //协程锁
    private val muteX by lazy {
        Mutex()
    }
​
    //分配的下载块大小
    private val blockSize by lazy {
        ceil(contentLength.toFloat() / blockCount).roundToLong()
    }
​
    //下载任务
    private val unCompleteTaskList = mutableListOf<DownloadTask>()
​
    //是否调用了init函数
    private val init
        get() = client != null
​
    //HTTP客户端
    @Transient
    private var client: HttpClient? = null
​
    //下载任务
    @Transient
    private var job: Job? = null
​
    //已经下载的大小
    val downloadSize: Long
        get() = unCompleteTaskList.sumOf { it.downloadSize }
​
    //已经下载的进度
    val progress
        get() = downloadSize.toFloat() / contentLength
​
    //是否正在下载
    val isDownloading
        get() = job != null
​
    fun init(client: HttpClient) {
        check(!init) {
            "service already init"
        }
        this.client = client
        if (unCompleteTaskList.isEmpty()) {
            unCompleteTaskList.addAll(
                //Range是闭区间,step值需要比blockSize大1以避免上一个块的末尾和下一个块的头重复
                (0..<contentLength step blockSize + 1).map {
                    val start = it
                    val end = min(it + blockSize, contentLength)
                    DownloadTask(
                        start = start,
                        end = end,
                        downloadSize = 0
                    )
                }
            )
        }
    }
​
    fun start(scope: CoroutineScope, block: HttpRequestBuilder.() -> Unit = {}) {
        //检查是否初始化
        check(init) {
            "service not init"
        }
        //检查是否有别的下载任务
        check(job == null) {
            "download still in progress"
        }
        //启动一个协程
        job = scope.launch {
            //筛选出未下载完成的切片
            unCompleteTaskList.filter { it.downloadSize <= blockSize }.map { taskConfig ->
                //启动子协程,在子协程中查询文件
                launch {
                    //使用prepareGet而不是get,避免ktor帮我们将body下载好。
                    val statement = client!!.prepareGet(url) {
                        block()
                        headers {
                            //Range头,通过这样的格式获取切片数据
                            append("Range", "bytes=${taskConfig.start + taskConfig.downloadSize}-${taskConfig.end}")
                        }
                    }
                    //开始执行请求
                    statement.execute {
                        //获取channel
                        val channel = it.bodyAsChannel()
​
                        while (true) {
                            //读取最多bufferSize的字节,若无读取内容则返回空数组。此时证明切片下载完毕。
                            val bytes = channel.readRemaining(bufferSize).readBytes()
                            if (bytes.isEmpty()) {
                                break
                            }
                            //RandomAccessFile的seek函数非线程安全。
                            muteX.withLock {
                                //IO操作需要转移到别的线程中运行。
                                withContext(Dispatchers.IO) {
                                    file.seek(taskConfig.start + taskConfig.downloadSize)
                                    file.write(bytes, 0, bytes.size)
                                }
                                taskConfig.downloadSize += bytes.size
                            }
                        }
                    }
                }
            }.joinAll() //等待全部任务下载完成,然后将job置为null代表下载任务结束。
            job = null
        }
    }
​
    suspend fun await() {
        job?.join() ?: throw IllegalStateException("任务未开始")
    }
​
    fun pause() {
        check(init) {
            "service not init"
        }
        check(job != null) {
            "download not started"
        }
        job!!.cancel()
        job = null
    }
​
    companion object {
        suspend fun create(
            block: DownloadServiceConfig.() -> Unit
        ): DownloadService {
            val config = DownloadServiceConfig().apply(block)
​
            val client = HttpClient(CIO) {
                install(HttpTimeout)
            }
            val url = config.url
            val response = client.head(url) {
                headers.append("Accept-Ranges", "acceptable-ranges")
            }.call.response.headers
​
            check(response["Accept-Ranges"] == "bytes") {
                "$url not support multi-thread download"
            }
            val contentLength = response["Content-Length"]?.toLongOrNull()
            check(contentLength != null && contentLength > 0) {
                "Content Length is Null"
            }
            return DownloadService(
                url = url,
                contentLength = contentLength,
                targetPath = config.targetPath.absolutePath,
                blockCount = config.blockCount,
                bufferSize = config.bufferedSize
            )
        }
    }
}
​
class DownloadServiceConfig {
    //下载地址
    lateinit var url: String
    //目标文件
    lateinit var targetPath: File
    //分块数
    var blockCount: Int = 16
    var bufferedSize: Long = 1024
}
​
@Serializable
data class DownloadTask(
    var start: Long,
    var end: Long,
    var downloadSize: Long = 0
)
​

3.7 实战

让我们挑一个受害者文件:

https://download.oracle.com/java/19/archive/jdk-19.0.2_linux-x64_bin.tar.gz

它的SHA-256是59f26ace2727d0e9b24fc09d5a48393c9dbaffe04c932a02938e8d6d582058c6

让我们写一个主函数下载文件:

fun main(): Unit = runBlocking {
    //实例化CIO客户端
    val client = HttpClient(CIO) {
        install(HttpTimeout) {
            requestTimeoutMillis = 30.minutes.inWholeMicroseconds
            connectTimeoutMillis = 30.minutes.inWholeMicroseconds
            socketTimeoutMillis = 30.minutes.inWholeMicroseconds
        }
    }
​
    //创建下载任务
    val service = DownloadService.create {
        url = "https://download.oracle.com/java/19/archive/jdk-19.0.2_linux-x64_bin.tar.gz"
        targetPath = File("jdk.tar.gz").apply {
            if (exists()) {
                delete()
            }
            //仅在第一次时初始化文件,若文件被删除可抛出异常
            absoluteFile.parentFile.mkdirs()
            createNewFile()
        }
    }
​
    //绑定下载任务到客户端
    service.init(client)
    //开始下载
    service.start(this)
​
    launch {
        //是否下载完成
        while (service.isDownloading) {
            delay(1.seconds)
            with(service) {
                println("download: $downloadSize --- $contentLength (${progress})")
            }
        }
    }
    runCatching {
        service.await()
    }.onFailure {
        println("任务被暂停或遇到异常")
    }
}

在这期间想要保存进度了可以使用下面的代码将当前的进度保存为json字符串:

File("progress.json").writeString(Json.encodeToString(service))

读取保存进度则可使用:

val service1 = Json.decodeFromString<DownloadService>(File("progress.json").readText())
service1.init(client)
service.start(...)

4. 总结

在这篇教程中,我们学习了如何使用ktor获取文件切片,如何用RandomAccessFile将数据下载到指定的位置。

而且基于kotlin对协程的强大并发特性,使得我们不需要像Java那样传一些没必要的参到子类中,让代码变的直观明白。

最后附一份完整代码的github链接,有爱自取哦:

kagg886/ktor-multithread-download-demo (github.com)