使用Boost.Asio实现多线程断点续传

239 阅读2分钟
#include <iostream>
#include <fstream>
#include <vector>
#include <thread>
#include <boost/asio.hpp>
#include <boost/array.hpp>
#include <boost/asio/ssl.hpp>

using boost::asio::ip::tcp;
namespace ssl = boost::asio::ssl;

// 下载线程的参数结构体
struct DownloadThreadParams {
    std::string url; // 固件下载链接
    std::string filename; // 固件保存文件名
    uint64_t start; // 下载起始字节
    uint64_t end; // 下载结束字节
    bool finished; // 下载是否完成

    DownloadThreadParams(const std::string& url, const std::string& filename,
        uint64_t start, uint64_t end)
        : url(url), filename(filename), start(start), end(end), finished(false)
    {}
};

// 下载线程函数
void downloadThread(const DownloadThreadParams& params, ssl::context& ctx)
{
    std::ofstream file(params.filename, std::ios::binary | std::ios::app);
    if (!file.is_open()) {
        std::cerr << "Failed to open file: " << params.filename << std::endl;
        return;
    }

    boost::asio::io_context io_context;
    ssl::stream<tcp::socket> socket(io_context, ctx);

    tcp::resolver resolver(io_context);
    auto endpoints = resolver.resolve(params.url, "https");

    boost::asio::connect(socket.next_layer(), endpoints);
    socket.handshake(ssl::stream_base::client);

    std::string request = "GET " + params.url + " HTTP/1.1\r\n" +
                          "Host: " + params.url + "\r\n" +
                          "Range: bytes=" + std::to_string(params.start) + "-" + std::to_string(params.end) + "\r\n" +
                          "Connection: close\r\n\r\n";

    boost::asio::write(socket, boost::asio::buffer(request));

    boost::array<char, 8192> buffer;
    boost::system::error_code error;

    while (boost::asio::read(socket, boost::asio::buffer(buffer), error)) {
        file.write(buffer.data(), buffer.size());

        if (error) {
            break;
        }
    }

    if (error != boost::asio::error::eof) {
        std::cerr << "Error during download: " << error.message() << std::endl;
    } else {
        params.finished = true;
    }

    file.close();
}

int main()
{
    std::string url = "https://www.example.com/firmware.bin";
    std::string filename = "firmware.bin";
    uint64_t fileSize = 0; // 固件文件大小

    // 创建SSL上下文
    boost::asio::io_context io_context;
    ssl::context ctx(ssl::context::tlsv12_client);

    // 分段下载
    uint64_t segmentSize = fileSize / 4; // 每个线程负责下载固件的1/4部分

    std::vector<std::thread> threads;
    std::vector<DownloadThreadParams> threadParams;

    for (int i = 0; i < 4; i++) {
        uint64_t start = i * segmentSize;
        uint64_t end = (i + 1) * segmentSize - 1;

        // 最后一个线程下载剩余的字节
        if (i == 3) {
            end = fileSize - 1;
        }

        DownloadThreadParams params(url, filename, start, end);
        threadParams.push_back(params);

        threads.emplace_back(downloadThread, std::ref(threadParams[i]), std::ref(ctx));
    }

    // 等待所有下载线程完成
    for (auto& thread : threads) {
        thread.join();
    }

    // 检查是否所有线程下载完成
    for (const auto& params : threadParams) {
        if (!params.finished) {
            std::cout << "Download incomplete." << std::endl;
            return 1;
        }
    }

    std::cout << "Download complete." << std::endl;

    return 0;
}

在主函数中,我们创建了一个ssl::context对象ctx,该对象用于SSL连接的上下文设置。然后,我们根据固件文件大小,计算要下载的每个线程的字节范围,创建相应的DownloadThreadParams对象,并将其存储在threadParams向量中。

然后,我们使用threads向量创建多个下载线程,每个线程调用downloadThread函数来执行下载操作。传递给downloadThread函数的参数包括DownloadThreadParams对象和SSL上下文ctx。每个线程使用SSL连接进行安全下载。

最后,我们等待所有下载线程完成,并检查每个线程的finished标志,以确保所有线程都成功完成了下载。

请注意,此示例假设HTTPS服务器使用TLSv1.2协议进行通信。您可能需要根据实际情况更改SSL上下文的设置,例如选择不同的协议版本或加载证书。