你也可以手敲一个高速下载器(五)基础下载

296 阅读6分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第18天,点击查看活动详情


你也可以手敲一个高速下载器(五)基础下载

前言

有了前几篇的基础,我们这节来实现一下基础的下载程序,我们的目标也很简单,就是能完成下载就行,先不要完善细节部分。

初始化

我们的初始化方法定义了一些属性,先看代码吧:

class HSDownloader(object):

    def __init__(self, url: str, save_path: str, concurrency: int = 64):
        """
        :param url: 要下载的地址
        :param save_path: 保存的路径
        :param concurrency: 并发的数量
        """
        self.url = url
        self.save_path = save_path
        self.headers = {}
        self._sem = asyncio.Semaphore(concurrency)
        self._head_headers = None

        self.block_number = 64
        self.target_size = self.block_number * 10 * 1024 * 1024

这部分代码很简单,就是初始化一些变量。要注意的地方最后两行,这部分代表一个我们下载的策略,其中target_size是一个临界点大小,block_number则是一个目标的块数量,关系分块下载时会根据这个大小来决定采取什么方案去多任务下载。

获取 HEAD 请求的响应头

如题,这一步就是通过HEAD请求的响应头来获取一些元数据,其中就会使用到我们上节讲到的通用请求类Request,代码如下:

    async def get_head_headers(self):
        """
        获取HEAD请求的响应头
        :return:
        """
        req = Request('HEAD', self.url, sem=self._sem)
        resp = await req.request()
        self._head_headers = resp.headers
        await req.close()

        return self._head_headers

发生 HEAD 请求,然后保存响应头至成员变量self._head_headers

从头中获取数据

此部分数据皆是从响应头中获取的,所以直接使用了属性@property的方式

是否接受断点续传

    @property
    def accept_ranges(self):
        """
        是否接受断点续传
        :return:
        """
        return self._head_headers.get('Accept-Ranges') is not True

获取资源总大小

    @property
    def content_length(self):
        """
        获取资源总大小
        :return:
        """
        return int(self._head_headers.get('Content-Length'))

格式化输出文件大小

这里是以可读的字符串返回文件大小,会调用下面提到的方法

    @property
    def file_size(self):
        """
        格式化输出文件大小
        :return:
        """
        return format_size(self.content_length)

格式化输出文件大小

由于我们会打印文件大小的日志,所以就是根据大小来采取合适的单位去输出,比如BKBMBGB等等,而他们之间的换算都是 1024,所以代码可以这么写:

def format_size(filesize: float):
    """
    格式化输出文件大小
    :param filesize: 文件大小
    :return: 返回格式化的字符串
    """
    for count in ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB']:
        if -1024.0 < filesize < 1024.0:
            return f"{filesize:3.1f} {count}"
        filesize /= 1024.0
    return f"{filesize:3.1f} YB"

获取内容

这里就是发生 GET 请求获取实际内容的地方,暂时会接受三个参数,分别的块的索引、块的开始字节、块的结束字节,在请求的过程中会捕获请求错误的异常,以及进行判断是否使用了分块传输,具体代码如下:

    async def get_content(self, index, start=None, end=None):
        """
        获取内容
        :param index: 索引
        :param start: 开始范围
        :param end: 结束范围
        :return: 返回内容的数据
        """
        logger.info(f"开始下载:bytes={start}-{end}")

        if end:
            headers = dict(self.headers, Range=f"bytes={start}-{end}")
        else:
            headers = self.headers.copy()
        req = Request("GET", self.url, sem=self._sem, headers=headers)
        try:
            resp = await req.request()
        except RequestException:
            return self.network_error_exit()
        finally:
            await req.close()

        return {'index': index, 'content': resp.content}

异常处理

我们在通用请求类里面抛出了异常,然后在请求类里面捕获了异常,并且在方法network_error_exit中进行处理了,目前为止这个方法很简单只是输出日志和退出程序。

    def network_error_exit(self):
        """
        发生网络错误异常
        :return:
        """
        logger.critical(f"网络请求发生错误,下载失败!")
        sys.exit(-1)

开始下载

这是我们的核心方法,用来组拼各个任务的参数,也就是每个任务的开始和结束字节部分,在这里面我们分成了两种方案,一种是资源总大小小于等于 64 _ 10 = 640 MB 的时候,采取的方案为任务数量固定为 64 个,大小平均分,另一种是资源总大小大于 64 _ 10 = 640 MB 的时候,采取的方案为每个任务下载的大小固定为 10M,任务数量动态计算。

 async def start_download(self):
        """
        开始下载
        :return:
        """
        # 第一种下载方案:资源总大小小于等于 64 * 10 = 640 MB: 任务数量固定为64个,大小平均分
        # 第二种下载方案:资源总大小大于 64 * 10 = 640 MB: 每个任务下载的大小固定为10M,任务数量动态计算
        if self.content_length <= self.target_size:
            block_size = int(self.content_length / self.block_number)
            block_number = self.block_number
        else:
            block_size = 10 * 1024 * 1024
            block_number = math.ceil(self.content_length / block_size)

        # 计算每个任务的参数,存到字典中
        args_dict = dict()
        for index in range(block_number):
            s = index * block_size if index == 0 else args_dict.get(index - 1).get('e') + 1
            e = s + block_size if index < block_number - 1 else self.content_length
            args_dict[index] = {"s": s, "e": e}

        # 转换成元组列表
        args = [
            (k, v['s'], v['e'])
            for k, v in args_dict.items()
        ]

        # 开启多任务
        tasks = itertools.starmap(self.get_content, args)
        result = await asyncio.gather(*tasks)
        return sorted(result, key=itemgetter('index'))

前两段只是计算任务块数量及任务块大小的代码,最后则是组拼参数并发下载的关键。前文提到过 python 如何使用异步进行并发:Python 并行编程实践(下),里面说过如何使用asyncio.gather进行并发,但是里面用的方法是直接传递函数的方法,在我们真实项目中,大部分都是动态计算出来的,所以这里可以使用itertools.starmap的方法组拼函数和参数即可,最后在按照索引进行排序,使各个块之间顺序一致

文件保存

这里的保存文件使用的是aiofiles的库,为啥不用自带的文件保存呢?因为要配合异步来使用,接收的参数只要两个,一个是文件路径,一个是文件内容,具体代码如下:

    async def save_file(self, filepath, content):
        """
        保存文件
        :param filepath: 文件路径
        :param content: 内容
        :return:
        """
        async with aiofiles.open(filepath, mode='wb') as f:
            await f.write(content)

下载入口

此方法是正式下载的入口,会调用前面提到的方法,发起HEAD请求获取基本信息、打印日志、开始下载、计算时间、计算平均下载速度,保存文件等,代码如下:

    async def start(self):
        """
        开始下载
        :return:
        """
        await self.get_head_headers()

        s_time = time.perf_counter()
        logger.success(f"开始下载资源,总大小:{self.file_size}")

        # 开始下载任务
        result = await self.start_download()

        e_time = time.perf_counter()

        # 计算平均下载速度
        average_speed = self.content_length / (e_time - s_time)
        average_speed = format_size(average_speed)

        logger.success(f"下载成功, 总用时:{e_time - s_time:.3f}s 平均速度:{average_speed}/s")

        # 保存文件
        await self.save_file(self.save_path, b''.join([d['content'] for d in result]))

代码仓库:

本节的代码以上传至 Github,请自行下载及观看:第五节代码

结语

这一节完成了高速下载程序的基本实现,但还有很多细节没有实现,后面我们在依次,下节更精彩,敬请期待!!!