实现思路
- 根据文件参数（MD5、大小、名称）创建一个上传任务，由服务端分配一个任务 ID
- 按序号循环上传每一个分片，服务端校验分片的 MD5 后将其保存为临时文件
- 按序号合并分片文件，校验合并后文件的 MD5 值，最后清理临时文件
服务端代码
-
Controller@RestController public class FileController { @Autowired private FileUploadService fileUploadService; @PostMapping("/createTask") public TaskModel createTask(CreateTaskDTO dto) { return fileUploadService.createUploadTask(dto.getFileName(), dto.getFileMd5(), dto.getFileSize()); } @PostMapping("/uploadFile") public String uploadFile(UploadFileDTO dto) { fileUploadService.uploadFile(dto.getTaskId(), dto.getMd5(), dto.getNo(), dto.getFile()); return "ok"; } @PostMapping("/mergeFile") public String mergeFile(UploadFileDTO dto) { fileUploadService.mergeFile(dto.getTaskId()); return "ok"; } } -
Serviceimport org.springframework.stereotype.Service; import org.springframework.util.DigestUtils; import org.springframework.util.FileCopyUtils; import org.springframework.web.multipart.MultipartFile; import java.io.File; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.Map; import java.util.Objects; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @Service public class FileUploadService { private static final Map<String, TaskModel> fileMap = new ConcurrentHashMap<>(); private final String tmpPath = "fileTmp/"; public TaskModel createUploadTask(String fileName, String fileMd5, Long fileSize) { createTmpPath(tmpPath); TaskModel taskModel = new TaskModel(); long sliceNum = fileSize / taskModel.getSize(); if (fileSize % taskModel.getSize() != 0) { sliceNum += 1; } taskModel.setFileMd5(fileMd5); taskModel.setId(UUID.randomUUID().toString()); taskModel.setSliceNum(sliceNum); taskModel.setFileName(fileName); taskModel.setFileSize(fileSize); fileMap.put(taskModel.getId(), taskModel); return taskModel; } public void uploadFile(String taskId, String md5, Integer no, MultipartFile file) { String tmpPath0 = tmpPath + taskId + "/"; createTmpPath(tmpPath0); String tmpFileName = tmpPath0 + no + ".tmp"; byte[] bytes = new byte[0]; try { bytes = file.getBytes(); String md5Value = DigestUtils.md5DigestAsHex(file.getBytes()); if (!md5.equals(md5Value)) throw new OpenApiException("MD5校验失败"); FileCopyUtils.copy(bytes, new File(tmpFileName)); } catch (IOException e) { e.printStackTrace(); throw new OpenApiException("文件上传失败"); } } public void mergeFile(String taskId) { TaskModel taskModel = fileMap.get(taskId); if (taskModel == null) return; try { byte[] reduce = Files.list(Paths.get(tmpPath + "/" + taskId + "/")).sorted((s1, s2) -> { String[] split = s1.toString().split("\\\\"); String[] split2 = s2.toString().split("\\\\"); String[] split1 = split[split.length - 1].split("\\."); 
String[] split3 = split2[split2.length - 1].split("\\."); return Long.compare(Long.parseLong(split1[0]), (Long.parseLong(split3[0]))); }) .map(this::readByteByPath).filter(Objects::nonNull) .reduce(new byte[]{}, this::addBytes); String s = DigestUtils.md5DigestAsHex(reduce); if (!taskModel.getFileMd5().equals(s)) { throw new OpenApiException("MD5校验错误"); } FileCopyUtils.copy(reduce, new File(tmpPath + "/" + taskId + taskModel.getFileName())); deleteDir(tmpPath + taskId); } catch (IOException e) { e.printStackTrace(); } } private void deleteDir(String path) { try { Files.list(Paths.get(path)) .forEach(p -> { try { Files.delete(p); } catch (IOException e) { e.printStackTrace(); } }); Files.delete(Paths.get(path)); } catch (IOException e) { e.printStackTrace(); } } private byte[] readByteByPath(Path i) { try { return Files.readAllBytes(i); } catch (IOException e) { e.printStackTrace(); return null; } } public byte[] addBytes(byte[] data1, byte[] data2) { byte[] data3 = new byte[data1.length + data2.length]; System.arraycopy(data1, 0, data3, 0, data1.length); System.arraycopy(data2, 0, data3, data1.length, data2.length); return data3; } private void createTmpPath(String tmpPath) { Path path = Paths.get(tmpPath); if (Files.notExists(path)) { try { Files.createDirectory(path); } catch (IOException e) { e.printStackTrace(); throw new OpenApiException("创建失败"); } } } } -
DTO类public class CreateTaskDTO { private String fileName; private String fileMd5; private Long fileSize; getter setter... } public class TaskModel { private String id; // 上传文件ID private Long sliceNum; // 分片次数 private Integer size = 1024; // 每片大小 字节 private String fileMd5; // 文件MD5值 private String fileName; // 文件名称 private Long fileSize; // 文件大小 } public class UploadFileDTO { private String taskId; // 任务ID private String md5; // 分片文件MD5 private Integer no; // 分片文件序号 private MultipartFile file; // 分片文件 }
客户端代码
import com.alibaba.fastjson.JSONObject;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.io.FileSystemResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.util.DigestUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;
import java.io.*;
import java.nio.file.Paths;
import java.util.function.BiConsumer;
@RunWith(SpringRunner.class)
@SpringBootTest
public class FileUploadTest {

    @Autowired
    private RestTemplate restTemplate;

    /** Local file uploaded by these tests. */
    String fileLocal = "01.png";

    /**
     * Step 1: creates an upload task for {@link #fileLocal} and prints the returned task JSON.
     * Copy that JSON into {@link #s} before running {@link #testUploadFile()} and {@link #mergeFile()}.
     */
    @Test
    public void testCreateTaskId() throws IOException {
        String url = "http://127.0.0.1:9090/createTask";
        File file = new File(fileLocal);
        MultiValueMap<String, Object> param = new LinkedMultiValueMap<>();
        param.add("fileName", fileLocal);
        // Stream is closed after hashing; an unreadable file fails the test instead of sending no MD5.
        try (InputStream in = new FileInputStream(file)) {
            param.add("fileMd5", DigestUtils.md5DigestAsHex(in));
        }
        param.add("fileSize", file.length());
        sendFile(url, param);
    }

    /** Task JSON printed by {@link #testCreateTaskId()}; its id must belong to a task the running server holds. */
    String s = "{\"id\":\"681aff74-ee43-4d2d-9488-7568854315c7\",\"sliceNum\":57,\"size\":1024,\"fileMd5\":\"55b1bfaa8360f333082956790a10ca8f\",\"fileName\":\"01.png\",\"fileSize\":58185}\n";

    /** Step 2: slices {@link #fileLocal} by the server-provided slice size and uploads each slice with its MD5. */
    @Test
    public void testUploadFile() throws Exception {
        String url = "http://127.0.0.1:9090/uploadFile";
        JSONObject jsonObject = JSONObject.parseObject(s);
        String id = jsonObject.getString("id");
        Integer size = jsonObject.getInteger("size");
        sliceFile(new File(fileLocal), size, (no, bytes) -> {
            MultiValueMap<String, Object> param = new LinkedMultiValueMap<>();
            param.add("taskId", id);
            param.add("md5", DigestUtils.md5DigestAsHex(bytes));
            param.add("no", no);
            // Stage the slice on disk so it can be sent as a FileSystemResource file part.
            File file = new File(id + ".tmp." + no);
            try {
                java.nio.file.Files.write(Paths.get(file.toURI()), bytes);
                param.add("file", new FileSystemResource(file));
                sendFile(url, param);
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            } finally {
                // Best-effort removal of the staging file, whether or not the upload succeeded.
                file.delete();
            }
        });
    }

    /** Step 3: asks the server to merge the uploaded slices of the task in {@link #s}. */
    @Test
    public void mergeFile() {
        JSONObject jsonObject = JSONObject.parseObject(s);
        String id = jsonObject.getString("id");
        MultiValueMap<String, Object> param = new LinkedMultiValueMap<>();
        param.add("taskId", id);
        sendFile("http://127.0.0.1:9090/mergeFile", param);
    }

    /** POSTs {@code param} to {@code url} and prints the response body. */
    public void sendFile(String url, MultiValueMap<String, Object> param) {
        HttpEntity<MultiValueMap<String, Object>> httpEntity = new HttpEntity<>(param);
        ResponseEntity<String> responseEntity = restTemplate.postForEntity(url, httpEntity, String.class);
        System.out.println(responseEntity.getBody());
    }

    /**
     * Splits {@code file} into consecutive slices of {@code size} bytes (the last may be shorter)
     * and passes each to {@code consumer} with its zero-based index. An empty file yields no slices.
     *
     * @param file     file to slice
     * @param size     slice size in bytes; must be positive
     * @param consumer receives (slice index, slice bytes)
     * @throws IllegalArgumentException if {@code size} is not positive
     * @throws UncheckedIOException     if the file cannot be opened or read
     */
    public static void sliceFile(File file, int size, BiConsumer<Long, byte[]> consumer) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive: " + size);
        }
        try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) {
            long length = randomAccessFile.length();
            // Ceiling division: a trailing partial slice counts as one more slice.
            long count = length / size;
            if (length % size != 0) {
                count++;
            }
            for (long i = 0; i < count; i++) {
                long offset = i * size;
                byte[] bytes = new byte[(int) Math.min(size, length - offset)];
                randomAccessFile.seek(offset);
                // readFully: a plain read() may return fewer bytes than requested.
                randomAccessFile.readFully(bytes);
                consumer.accept(i, bytes);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
文中的 OpenApiException 是自定义的运行时（非受检）异常，需要自行定义。