Java 多线程+断点续传模式 下载网络资源

1,801 阅读6分钟

前言

最开始准备研究这个东西,就是为了用 Java 去爬番剧的下载地址,然后,缓存到本地有空了慢慢刷(半个T的小洋房这不得吧小姐姐们都安排进去嘛,空着多不好丫)

同时,考虑到一些常见的问题,比如突然断电断网什么,下到差不多的文件直接报废又得从头开始,就一定需要用上 断点续传 的功能

最开始的一个设计思路,就是先将文件分片下载到本地,存储为类似 1.temp, 2.temp, 3.temp ... 的形式,都下载好了之后,最后再将文件合并成一个我们需要的 xx.mp4

在开发过程中,查阅不少资料,发现有提供文件数据偏移坐标读取和写入的功能,那思路就调整为先生成一个一样大小的文件,然后多个线程在这个文件上划区域进行 数据更新

知识点

java.net.URL

通过 new URL(String) 的形式,传入一个字符串格式的 http 资源链接,构造出一个 URL 对象

示例

URL resource = new URL("https://www.wahaotu.com/uploads/allimg/201904/1555074510295049.jpg");

HttpURLConnection

用于链接 URL 对应的资源

主要用到的几个点:

打开资源链接

HttpURLConnection conn = (HttpURLConnection) url.openConnection()

打开链接后,默认获取的对象是 URLConnection,我们需要强转为更加具体的子类 HttpURLConnection

设置请求头

conn.setRequestProperty(key, value);

我们是通过 URL 来模拟发送 http 请求,为了让我们的请求显得更加自然,需要加上 User-Agent

因为我们需要分段去向服务器请求资源,需要添加请求头 Rangevalue 的格式为 bytes=start-end,其中,这个数据的范围可以理解为 [start, end],我们需要自己提前切割好需要分段请求的资源的起始结束位。这一步需要在获取输入流之前执行

获取资源输入流

conn.getInputStream()

当我们的准备工作做好之后,就可以获取到输入流,然后可以根据标准的字节 IO 流来保存文件

关闭链接

conn.disconnect()

同样的是资源流,需要我们手动去关闭避免资源浪费,也可以使用 JDK7 的try resource 特性来自动关闭,当然,也别忘了关闭文件的输出流

RandomAccessFile

这是一个随机读写的文件类,网上的解释是这是对 InputStream, OutputStream 两个类的封装,实际的使用也是如此,该类提供有 read, write 等相关操作的方法

重点:这个类支持随机读写!!!即我们可以指定从从文件中的某一个坐标开始进行读或写操作!!!

构造器

new RandomAccessFile(String fileName, String mode)
new RandomAccessFile(File file, String mode)

我们传入文件名后,内部会帮忙转为 File 然后调用第二个构造器

mode

算是小小的缺陷,没有使用枚举类。阅读源码,内部使用了字符串的比较,支持 4 种值

  1. "r",只读
  2. "rw",可读可写
  3. "rws",支持读写,并要求对文件内容或元数据的每次更新都同步写入底层存储设备
  4. "rwd",支持读写,并要求对文件内容的每次更新都同步写入底层存储设备

设置文件的偏移指针

重头戏,告诉 JVM 偏移多少后对文件读写操作,提供了有 2 个方法

seek(long pos)

设置文件指针的偏移量,从该文件开始计算,在此位置发生下一次读或写操作。偏移量可能超出了文件的末端。设置超出文件结尾的偏移量不会改变文件长度。只有当偏移量超出了文件的末尾时,文件长度才会发生变化。

skipBytes(int n)

尝试跳过n个字节的输入,丢弃跳过的字节。方法内部最后还是调用了 seek()

快速生成指定大小的空白文件

setLength(fileLength)

功能规划

  1. 根据 URL 来截取文件名,或者手动指定文件名
  2. 生成本地目标文件(有则跳过)
  3. 生成本地进度记录文件
    1. 如果文件存在,则获取到已经下载好的索引数据
  4. 文件资源切片,生成索引,去掉已经下载好的索引
  5. 启用线程池,将待下载的索引分配给线程池中的线程

成品代码

package com.wb.down;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;

// v3 多线程,文件分片复制,文件来源修改为网络来源
@SuppressWarnings("resource")
public class Down {
    // 配置线程池
    int threadSize = 16;
    private ExecutorService threads = Executors.newFixedThreadPool(threadSize);
    CountDownLatch latch = new CountDownLatch(0);
    
    static final String FILE_ACCESS_MODE = "rwd";
    
    String source; // 源,http 链接
    String dir = "E:/down/"; // 本地文件下载的路径
    String fileName; // 待下载的文件名
    String tempFileName; // 本地记录进度的 temp 文件
    
    static final int LEN = 1024 * 1024 * 6; // M, 文件切片大小
    
    private Set<Integer> used = new HashSet<>(); // 已被使用
    private Set<Integer> todo = new HashSet<>(); // 待做任务
    
    private Map<Integer, String> ranges = new HashMap<>(); // 切片数据,给 URL 分片去拉数据
    
    // 通过 URL 获取链接资源
    private HttpURLConnection getConn() throws Exception {
        URL url = new URL(source);
        HttpURLConnection conn = (HttpURLConnection) url.openConnection();
        return conn;
    }
    
    private static String getFileNameFromPath(String path) {
        String[] dirs = path.split("/");
        return dirs[dirs.length - 1];
    }
    
    private String getLocalPath() {
        return dir + fileName;
    }
    
    public Down(String source) {
        this(source, getFileNameFromPath(source));
    }
    
    public Down(String source, String fileName) {
        this.source = source;
        this.fileName = fileName;
        this.tempFileName = getLocalPath() + ".temp"; // 缓存文件,进度记录
        init();
    }
    
    // 判断文件是否成功下载
    private boolean isPresent() {
        if (new File(getLocalPath()).exists() && !new File(tempFileName).exists()) {
            return true;
        }
        return false;
    }
    
    // 初始化操作
    private void init() {
        if (isPresent()) {
            System.out.printf("[%s] 文件已经成功下载", fileName);
            return;
        }
        try {
            System.out.println("===> 创建本地文件");
            createLocalFileIfNotExist();
            System.out.println("===> 创建进度记录文件");
            processProgressFile();
            System.out.println("===> 文件切片处理");
            createDownIndexBySplit();
            System.out.println("===> 初始化结束");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
    
    //	生成本地目标文件,如果不存在
    private void createLocalFileIfNotExist() throws Exception {
        File file = new File(getLocalPath());
        if (!file.getParentFile().exists()) {
            file.getParentFile().mkdirs();
        }
        if (!file.exists()) {
            RandomAccessFile accessFile = new RandomAccessFile(file, FILE_ACCESS_MODE);
            accessFile.setLength(getConn().getContentLengthLong());
        }
    }
    
    //	处理本地进度记录文件
    private void processProgressFile() throws IOException {
        File temp = new File(tempFileName);
        if (!temp.exists()) { // 没有就创建
            temp.createNewFile();
        } else { // 存在则更新已下载的索引数据
            BufferedReader bufferedReader = new BufferedReader(new FileReader(temp));
            String str = bufferedReader.readLine();
            if (str == null)
                return;
            bufferedReader.close();
            for (String s : str.split(",")) {
                used.add(Integer.valueOf(s));
            }
        }
    }
    
    // 切割文件拿到索引
    private void createDownIndexBySplit() throws Exception {
        int fileLen = getConn().getContentLength();
        // [0, 9] [10, 19]
        for (int i = (int) (fileLen / LEN); i >= 0; i--) {
            ranges.put(i, "bytes=" + i * LEN + "-" + Math.min(fileLen + 1, (i + 1) * LEN));
            // System.out.println(ranges.get(i));
        }
        todo.addAll(ranges.keySet());
        // 去掉已经下过的索引
        todo.removeAll(used);
        latch = new CountDownLatch(todo.size());
    }
    
    // 以多线程的方式,来写文件
    public void down() {
        todo.stream().forEach(i -> threads.execute(new DownThread(i)));
        threads.shutdown();
        try {
            latch.await();
            new File(tempFileName).deleteOnExit();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
    
    class DownThread implements Runnable {
        Integer index;
        byte[] bs = new byte[1024 * 128];
        
        DownThread(Integer index) {
            this.index = index;
        }
        
        @Override
        public void run() {
            try {
                HttpURLConnection conn = getConn();
                // 设置切片文件位置
                conn.setRequestProperty("Range", ranges.get(index));
                // 让当前请求自然一些,防止被 403
                conn.setRequestProperty("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0");
                RandomAccessFile fos = new RandomAccessFile(getLocalPath(), FILE_ACCESS_MODE);
                // 读、写文件同步偏移
                fos.seek(LEN * index);
                // 写操作
                InputStream is = conn.getInputStream();
                int read;
                while ((read = is.read(bs)) != -1) {
                    fos.write(bs, 0, read);
                }
                fos.close();
                conn.disconnect();
                synchronized (Down.class) {
                    latch.countDown();
                    // 更新索引
                    used.add(index);
                    System.out.printf("当前文件:[%s], 下载片段:[%d], 进度:[%d %%] \n", fileName, index, (int) (used.size() * 100 / ranges.keySet().size()));
                    try {
                        // 更新到文件
                        String memo = used.stream().map(n -> n.toString()).collect(Collectors.joining(","));
                        Files.write(Paths.get(tempFileName), memo.getBytes());
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
    
    public static void main(String[] args) throws IOException {
        Down down1 = new Down("https://publish.u-tools.cn/version2/uTools-2.1.0.exe");
        down1.down();
        Runtime.getRuntime().exec("cmd /c start " + down1.dir);
    }
}

下载效果

很明显,带宽直接拉满

图片.png

总结

有时候,我们下载过慢,是因为服务端会故意限制我们的下载速度,具体表现为对单个连接,做下载速度检测,发现超过他们配置的最大速度就故意空闲一定时间再允许我们继续下载其余片段(这个结论是我的猜测,来源于我找的这一份资料 HttpURLConnection下载限速的方法

当然,如果你自己的下载速度本来就慢,那就怪不得别人的服务器了