引言
数据源接入处理的流程大致是:从数据库按行读取数据源 -> 写入 csv 文件输出流 -> 处理 csv 文件 -> 处理后 InputStream 写入文件。如果源表数据量大,读取并写入 csv 文件时服务器内存有 OOM 风险,影响服务可用性。采取分批处理的思路解决内存管理问题。
task execute 主体框架
每个获取、处理、写入数据的工作流作为一个独立的任务执行。
public void execute(JobExecutionContext context) throws JobExecutionException {
try {
// 从 context 中读取任务参数和数据访问地址
// 初始化分批标志位
int currentRow = 0;
boolean hasNextInputStream = true;
while (hasNextInputStream) {
// 从数据库按行读取数据源
InputStream inputStream = fetchData(currentRow, inputFileLocation);
currentRow += ROW_SIZE;
if (isNull(inputStream)) {
hasNextInputStream = false;
break;
}
// 写入 csv 文件输出流
String originalFile = fileStorageEngine.store(inputStream, originalFileName);
// 处理 csv 文件并将处理后 InputStream 写入文件
TaskResultDTO taskResultDTO = solveInBatch(originalFile, taskParam);
// 记录日志
} catch (Exception e) {
throw new JobExecutionException(e);
}
}
其中,ROW_SIZE 表示每批从数据库中读取数据表的行数。
fetchData 获取数据
根据数据来源协议,选取不同方法获取数据,以 InputStream 的形式返回。
/**
 * Fetches one batch of source data according to the location's protocol and
 * returns it as an InputStream, or null when the protocol is unsupported,
 * the source is exhausted (MySQL), or an HTTP fetch fails.
 *
 * @param currentRow        row offset, only meaningful for the MySQL protocol
 * @param inputFileLocation protocol plus protocol-specific connection parameters
 * @return the fetched data as a stream, or null (callers treat null as "stop")
 */
private InputStream fetchData(int currentRow, FileLocationDTO inputFileLocation) {
    InputStream inputStream = null;
    switch (inputFileLocation.getProtocol()) {
        case FTP:
            FTPParameters ftpParameters = inputFileLocation.getParameters().toJavaObject(FTPParameters.class);
            String ftpUrl = ftpParameters.getUrl();
            FTPClient ftpClient = new FTPClient();
            if (ftpParameters.isHasAccount()) {
                // Host is the third segment of ftp://host/path — assumes the URL
                // always carries the scheme prefix; TODO confirm against producers.
                String[] subStrs = ftpUrl.split("/");
                String ftpHost = subStrs[2];
                ftpClient = FTPUtil.loginFTP(ftpHost, ftpParameters.getAccount(), ftpParameters.getPassword());
            }
            // NOTE(review): when hasAccount is false the FTPClient is never
            // connected before download — presumably FTPUtil handles anonymous
            // connection internally; verify. Also, logout is skipped if download
            // throws; consider a finally block.
            inputStream = FTPUtil.downloadAsInputStream(ftpClient, ftpUrl);
            FTPUtil.logout(ftpClient);
            break;
        case HTTP:
            HTTPParameters httpParameters = inputFileLocation.getParameters().toJavaObject(HTTPParameters.class);
            ResponseEntity<Resource> responseEntity;
            if (httpParameters.isHasAccount()) {
                // Authenticated fetch: credentials are carried in request headers.
                responseEntity = restTemplate.exchange(httpParameters.getUrl(), HttpMethod.GET, new HttpEntity<>(generateHeaderByParameter(httpParameters)), Resource.class);
            } else {
                responseEntity = restTemplate.getForEntity(httpParameters.getUrl(), Resource.class);
            }
            Resource resource = responseEntity.getBody();
            if (isNull(resource)) {
                StandardLoggingUtil.error(LOGGER, "fail to fetch data, its parameters:" + httpParameters);
            } else {
                try {
                    inputStream = resource.getInputStream();
                } catch (IOException e) {
                    // Swallowed deliberately: a null return tells the caller to stop.
                    StandardLoggingUtil.error(LOGGER, "fail to fetch data, its parameters:" + httpParameters);
                }
            }
            break;
        case MySQL:
            // Only MySQL honors currentRow: it pages through the table ROW_SIZE
            // rows at a time and returns null when the table is exhausted.
            MysqlParameters mysqlParameters = inputFileLocation.getParameters().toJavaObject(MysqlParameters.class);
            inputStream = queryTableByParameters(currentRow, mysqlParameters);
            break;
        case OSS:
            OSSParameters ossParameters = inputFileLocation.getParameters().toJavaObject(OSSParameters.class);
            inputStream = fetchOSSFileByParameters(ossParameters);
            break;
        default:
            StandardLoggingUtil.info(LOGGER, "unsupported protocol:" + inputFileLocation.getProtocol());
    }
    return inputStream;
}
以 Mysql 为例,连接数据库,拼 sql 语句并执行,查询 ROW_SIZE 行数据。
/**
 * Queries one batch of ROW_SIZE rows from the configured MySQL table starting
 * at currentRow and renders the batch as a CSV InputStream.
 *
 * @param currentRow row offset used in the LIMIT clause
 * @param parameters MySQL connection and table settings
 * @return a CSV stream of the batch, or null when no rows remain or the query fails
 */
private InputStream queryTableByParameters(int currentRow, MysqlParameters parameters) {
    StructuredData data = new StructuredData();
    ByteArrayOutputStream os = new ByteArrayOutputStream();
    // NOTE(review): database/table names are concatenated into the SQL —
    // identifiers cannot be bound as parameters, so this is only safe while
    // parameters come from trusted configuration; confirm the source.
    String sql = format("SELECT * FROM `%s`.`%s`", parameters.getDatabase(), parameters.getTable()) + " LIMIT " + currentRow + " , " + ROW_SIZE;
    try (Connection conn = DriverManager.getConnection(parameters.getJDBCUrl(), parameters.getUsername(), parameters.getPassword());
         Statement stmt = conn.createStatement();
         // ResultSet joins the try-with-resources (the original left it open
         // until the Statement was closed).
         ResultSet rs = stmt.executeQuery(sql)) {
        ResultSetMetaData resultSetMetaData = rs.getMetaData();
        int columnCount = resultSetMetaData.getColumnCount();
        // Register headers in column order (1-based JDBC indexing).
        for (int i = 1; i <= columnCount; i++) {
            data.addHeader(resultSetMetaData.getColumnName(i), i);
        }
        while (rs.next()) {
            Map<String, Object> row = new HashMap<>(8);
            for (int i = 1; i <= columnCount; i++) {
                row.put(resultSetMetaData.getColumnName(i), rs.getString(i));
            }
            data.addRow(row);
        }
        // Empty batch => source exhausted; null tells the caller to stop.
        if (data.getRows().isEmpty()) {
            return null;
        }
        data.toCsvOutputStream(os);
    } catch (SQLException e) {
        // Fix: log at error level and return null. The original logged at info
        // and fell through to return a non-null EMPTY stream, which the
        // caller's isNull check never detects — a persistent DB failure
        // therefore looped forever, writing empty files each iteration.
        StandardLoggingUtil.error(LOGGER, "fail to query data from Mysql, its parameters:" + parameters);
        return null;
    }
    return new ByteArrayInputStream(os.toByteArray());
}
data.toCsvOutputStream(os) 方法将从数据库读取到的结构化数据写入 csv 文件。
/**
 * Serializes this structured data as CSV: a header line (column names ordered
 * by their registered index) followed by one line per data row.
 *
 * @param outputStream destination stream; closed when writing completes
 * @throws ErrorCodeException if the underlying write fails
 */
public void toCsvOutputStream(OutputStream outputStream) {
    List<Map<String, Object>> rows = this.getRows();
    Preconditions.checkNotNull(rows);
    Preconditions.checkArgument(!rows.isEmpty());
    // Sort column names by their index so the CSV column order is stable.
    String[] header = this.getHeaderIndexes()
            .entrySet()
            .stream()
            .sorted(Map.Entry.comparingByValue())
            .map(Map.Entry::getKey)
            .toArray(String[]::new);
    try (OutputStreamWriter streamWriter = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8);
         CSVWriter csvWriter = new CSVWriter(streamWriter)
    ) {
        // Header line first: the column names.
        csvWriter.writeNext(header);
        // Then each data row, cells emitted in header order.
        writeBody(rows, header, csvWriter);
    } catch (IOException e) {
        throw new ErrorCodeException(ErrorCodes.fromErrorTemplate(CommonErrorTemplate.IO_EXCEPTION));
    }
}
/**
 * Emits one CSV line per row, rendering cells in header order.
 * A key absent from a row is rendered as the literal "null"
 * (String.valueOf of a null lookup).
 *
 * @param rows   data rows keyed by column name
 * @param header column names in output order
 * @param writer destination CSV writer
 */
private void writeBody(List<Map<String, Object>> rows, String[] header, CSVWriter writer) {
    List<String[]> bodies = new ArrayList<>(rows.size());
    for (Map<String, ?> row : rows) {
        // Fix: size by header.length, not row.size(). A row with fewer keys
        // than the header previously threw ArrayIndexOutOfBoundsException,
        // and a row with extra keys produced a line with trailing nulls.
        String[] body = new String[header.length];
        for (int i = 0; i < header.length; i++) {
            body[i] = String.valueOf(row.get(header[i]));
        }
        bodies.add(body);
    }
    writer.writeAll(bodies);
}
solveInBatch 分批处理并写入文件
为保证内容完整,需要处理分批断口处的残缺字符。
/**
 * Processes the stored CSV file in batches of BATCH_SIZE bytes and writes the
 * results out, so the whole file is never held in memory at once. The first
 * batch creates the result file/table; subsequent batches append.
 *
 * NOTE(review): this snippet references names produced by elided processing
 * steps (columnNames, decryptedData) and an instance field taskInfo — it is
 * not self-contained as shown.
 *
 * @param originalFile path of the stored CSV file to process
 * @param taskParam    task settings (file format, target columns, task uuid)
 * @return result descriptor for the task
 * @throws Exception propagated from batch retrieval/processing
 */
private TaskResultDTO solveInBatch(String originalFile, TaskParam taskParam) throws Exception {
    String targetColumns = taskParam.getTargetColumns();
    GatewayFileLocationDTO inputFileLocation = taskInfo.getInputFileLocation();
    GatewayFileLocationDTO outputFileLocation = taskInfo.getOutputFileLocation();
    // Initialize batching state: byte offset into the file and a continue flag
    // driven by the retrieval result of each iteration.
    boolean hasNextBatch = true;
    int offset = 0;
    Map<String, Object> rv = null;
    FileParserService fileParserService;
    String resultFile = "";
    TaskResultDTO taskResultDTO = new TaskResultDTO();
    boolean success;
    // Only the first batch contains the header line; its parsed indexes are
    // reused for every later batch.
    boolean isHeader = true;
    Map<String, Integer> headerIndexes = null;
    // Batch loop: retrieve -> parse -> process -> persist.
    while (hasNextBatch) {
        // rv carries the previous batch's partial trailing line ("lastByte")
        // so records split at batch boundaries are reassembled.
        rv = fileStorageEngine.retrieveFileInBatch(rv, originalFile, BATCH_SIZE, offset);
        FileRetrieved fileRetrieved = generateFileRetrieved(originalFile, rv, offset);
        fileParserService = fileParserServiceProvider.getComponentInstance(taskParam.getFileFormat());
        StructuredData data = fileParserService.load(isHeader, headerIndexes, fileRetrieved.getInputStream());
        if (isHeader) {
            headerIndexes = data.getHeaderIndexes();
        }
        // Process the data as required (elided; presumably yields
        // decryptedData and columnNames used below — confirm in full source).
        // Write the processed result to a new file: first batch creates it,
        // later batches append.
        if (offset == 0) {
            resultFile = saveResultFile(taskParam, data, columnNames);
        } else {
            resultFile = saveResultFileTail(resultFile, data, columnNames);
        }
        // Mirror the same first-batch/tail split for the MySQL sink: batch 0
        // creates the table, later batches only insert rows.
        switch (offset) {
            case 0:
                success = writeDataToMysql(inputFileLocation, outputFileLocation, decryptedData);
                break;
            default:
                success = writeTailToMysql(outputFileLocation, decryptedData);
                break;
        }
        if (!success) {
            StandardLoggingUtil.error(LOGGER, "fail to write data, its parameters:" + outputFileLocation.getParameters());
        }
        // Advance to the next batch using the retrieval result.
        offset = (Integer) rv.get("nextOffset");
        hasNextBatch = (Boolean) rv.get("hasNextBatch");
        isHeader = false;
    }
    return taskResultDTO;
}
/**
 * Builds a FileRetrieved descriptor for one batch: the display file name
 * (the portion after the first CONNECT_CHAR of the stored name), the batch
 * content re-encoded as a UTF-8 stream, and the batch offset.
 *
 * @param originalFile stored file path
 * @param rv           batch retrieval result; its "content" entry holds the batch text
 * @param offset       byte offset of this batch within the file
 * @return populated descriptor for downstream parsing
 */
private FileRetrieved generateFileRetrieved(String originalFile, Map<String, Object> rv, int offset) {
    String storedName = Paths.get(originalFile).getFileName().toString();
    String displayName = storedName.substring(storedName.indexOf(CONNECT_CHAR) + 1);
    byte[] contentBytes = ((String) rv.get("content")).getBytes(StandardCharsets.UTF_8);
    FileRetrieved retrieved = new FileRetrieved();
    retrieved.setFileName(displayName);
    retrieved.setInputStream(new ByteArrayInputStream(contentBytes));
    retrieved.setOffset(offset);
    return retrieved;
}
/**
 * First-batch MySQL sink: creates a date-suffixed result table (columns and
 * comments derived from the input location) and inserts every row of the batch.
 *
 * NOTE(review): dateFormat is an instance field not shown in this snippet —
 * presumably the user-configured date rendering pattern; confirm. The insert
 * SQL is built by string concatenation in insertSql(...) — since row values
 * originate from external data sources, this should move to PreparedStatement
 * with bound parameters.
 *
 * @param inputFileLocation  source location, used to look up column comments
 * @param outputFileLocation target MySQL connection/table settings
 * @param data               processed batch to persist
 * @return true if table creation and all inserts succeeded, false otherwise
 */
private boolean writeDataToMysql(GatewayFileLocationDTO inputFileLocation, GatewayFileLocationDTO outputFileLocation, StructuredData data) {
    boolean success = true;
    MysqlParameters parameters = outputFileLocation.getParameters().toJavaObject(MysqlParameters.class);
    DataSource dataSource = new DriverManagerDataSource(parameters.getJDBCUrl(), parameters.getUsername(), parameters.getPassword());
    JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
    // Table name is suffixed with today's date; '-' replaced so the name is a
    // valid MySQL identifier.
    String tableName = format("%s_%s", parameters.getTable(), LocalDate.now().format(DateTimeFormatter.ofPattern("yyyyMMdd"))).replace('-', '_');
    Map<String, String> columnCommentMap = queryColumnCommentByGatewayFileLocationDTO(inputFileLocation);
    String tableSql = createTableSql(tableName, data.getHeaderIndexes(), columnCommentMap);
    try {
        jdbcTemplate.execute(tableSql);
        // Formatter built once per call (loop-invariant); SimpleDateFormat is
        // not thread-safe, so it must stay method-local.
        SimpleDateFormat simpleDateFormat = null;
        if (nonNull(dateFormat)) {
            simpleDateFormat = new SimpleDateFormat(dateFormat);
        }
        // Row-by-row inserts; consider batchUpdate for large batches.
        for (Map<String, Object> row : data.getRows()) {
            String insertSql = insertSql(tableName, row, simpleDateFormat);
            jdbcTemplate.execute(insertSql);
        }
    } catch (DataAccessException e) {
        // Any create/insert failure marks the whole batch as failed; the
        // caller logs it. NOTE(review): the exception detail is dropped here.
        success = false;
    }
    return success;
}
/**
 * Tail-batch MySQL sink: inserts the batch's rows into the date-suffixed
 * result table created by writeDataToMysql for batch 0.
 *
 * NOTE(review): insertSql(...) builds SQL by concatenation — row values come
 * from external sources, so this should move to PreparedStatement binding.
 *
 * @param outputFileLocation target MySQL connection/table settings
 * @param data               processed batch to persist
 * @return true if every insert succeeded, false otherwise
 */
private boolean writeTailToMysql(GatewayFileLocationDTO outputFileLocation, StructuredData data) {
    boolean success = true;
    MysqlParameters parameters = outputFileLocation.getParameters().toJavaObject(MysqlParameters.class);
    DataSource dataSource = new DriverManagerDataSource(parameters.getJDBCUrl(), parameters.getUsername(), parameters.getPassword());
    JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
    // Same date-suffixed identifier that writeDataToMysql created on batch 0.
    String tableName = format("%s_%s", parameters.getTable(), LocalDate.now().format(DateTimeFormatter.ofPattern("yyyyMMdd"))).replace('-', '_');
    try {
        // Hoisted out of the row loop (the original rebuilt the formatter for
        // every row): the pattern is loop-invariant and SimpleDateFormat
        // construction is comparatively expensive. Also matches
        // writeDataToMysql's structure.
        SimpleDateFormat simpleDateFormat = null;
        if (nonNull(dateFormat)) {
            simpleDateFormat = new SimpleDateFormat(dateFormat);
        }
        for (Map<String, Object> row : data.getRows()) {
            String insertSql = insertSql(tableName, row, simpleDateFormat);
            jdbcTemplate.execute(insertSql);
        }
    } catch (DataAccessException e) {
        // Any insert failure marks the whole batch as failed; caller logs it.
        success = false;
    }
    return success;
}
/**
 * Renders the first batch of processed data to a new Markdown result file.
 * Cells flagged in columnNames (row index -> set of column names) are
 * wrapped in red highlighting before rendering.
 *
 * @param taskParam   task settings; the task uuid names the result file
 * @param data        processed batch
 * @param columnNames per-row sets of column names to highlight
 * @return identifier/path of the stored result file
 */
private String saveResultFile(TaskParam taskParam, StructuredData data,
        Map<Integer, Set<String>> columnNames) {
    List<Map<String, Object>> rows = data.getRows();
    List<Map<String, Object>> highlighted = new ArrayList<>(rows.size());
    for (int rowIndex = 0; rowIndex < rows.size(); rowIndex++) {
        Map<String, Object> source = rows.get(rowIndex);
        Map<String, Object> target = new HashMap<>(source.size());
        for (Entry<String, Object> cell : source.entrySet()) {
            boolean hit = columnNames.containsKey(rowIndex)
                    && columnNames.get(rowIndex).contains(cell.getKey());
            target.put(cell.getKey(), hit ? colorRed(cell.getValue()) : cell.getValue());
        }
        highlighted.add(target);
    }
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    writeRowsToMarkdown(highlighted, buffer);
    return fileStorageEngine.store(new ByteArrayInputStream(buffer.toByteArray()),
            taskParam.getTaskUuid(), "R");
}
/**
 * Appends a tail batch of processed data to an existing Markdown result file,
 * applying the same red-highlighting rule as saveResultFile.
 *
 * @param resultFile  identifier/path of the result file created for batch 0
 * @param data        processed batch
 * @param columnNames per-row sets of column names to highlight
 * @return identifier/path of the (appended-to) result file
 */
private String saveResultFileTail(String resultFile, StructuredData data, Map<Integer, Set<String>> columnNames) {
    List<Map<String, Object>> rows = data.getRows();
    List<Map<String, Object>> highlighted = new ArrayList<>(rows.size());
    for (int rowIndex = 0; rowIndex < rows.size(); rowIndex++) {
        Map<String, Object> source = rows.get(rowIndex);
        Map<String, Object> target = new HashMap<>(source.size());
        for (Entry<String, Object> cell : source.entrySet()) {
            boolean hit = columnNames.containsKey(rowIndex)
                    && columnNames.get(rowIndex).contains(cell.getKey());
            target.put(cell.getKey(), hit ? colorRed(cell.getValue()) : cell.getValue());
        }
        highlighted.add(target);
    }
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    writeRowsToMarkdown(highlighted, buffer);
    return fileStorageEngine.append(new ByteArrayInputStream(buffer.toByteArray()), resultFile);
}
其中,BATCH_SIZE 表示分批大小,taskParam.getFileFormat() 包括 CSV 和 MYSQL 两种。获取文件:
/**
 * Reads one batch of up to batchSize bytes from filePath starting at offset,
 * prepending any partial trailing line carried over from the previous batch
 * (the "lastByte" entry of rv).
 *
 * @param rv        previous batch's result map, or null on the first batch
 * @param filePath  path of the stored file to read
 * @param batchSize maximum number of bytes to read this batch
 * @param offset    byte offset within the file to resume from
 * @return map with keys "content", "lastByte", "nextOffset", "hasNextBatch"
 * @throws IOException        on read failure
 * @throws ErrorCodeException if the file does not exist
 */
public Map<String, Object> retrieveFileInBatch(Map<String, Object> rv, String filePath, int batchSize, int offset) throws IOException {
    Preconditions.checkNotNull(filePath);
    // Partial record carried over from the previous batch, empty on batch 0.
    byte[] lastByte = {};
    if (rv != null && rv.containsKey("lastByte")) {
        lastByte = (byte[]) rv.get("lastByte");
    }
    // Fix: the original opened this FileInputStream and never closed it,
    // leaking a file handle on every batch. try-with-resources closes it
    // whether loadContent succeeds or throws. (Also removed: a FileRetrieved
    // instance that was built and then discarded unused.)
    try (InputStream input = new FileInputStream(filePath)) {
        return loadContent(batchSize, offset, input, lastByte);
    } catch (FileNotFoundException e) {
        throw new ErrorCodeException(fromErrorTemplate(FILE_NOT_FOUND, filePath), e);
    }
}
处理首尾残缺字符:
/**
 * Reads one batch from the stream and splits it at the last newline so that
 * "content" holds only complete lines; the trailing partial record
 * ("lastByte", beginning with the newline itself) is carried to the next
 * batch by the caller.
 *
 * NOTE(review): if the final read is short, the zero bytes padding the tail
 * of the buffer are swept into lastByte; harmless today because lastByte is
 * unused once hasNextBatch is false, but worth tightening.
 *
 * @param batchSize maximum bytes to read this batch
 * @param offset    byte offset within the file to resume from
 * @param input     freshly opened stream positioned at the file start
 * @param la        partial record carried over from the previous batch
 * @return map with keys "content", "lastByte", "nextOffset", "hasNextBatch"
 * @throws IOException on read failure
 */
private static Map<String, Object> loadContent(Integer batchSize, Integer offset,
        InputStream input, byte[] la) throws IOException {
    Map<String, Object> rv = new HashMap<>();
    // Buffer = carried-over bytes + this batch's fresh bytes.
    byte[] buffer = new byte[la.length + batchSize];
    copyToBuffer(buffer, la);
    // available() on the fresh stream approximates the file length; used only
    // for the end-of-file test below.
    int totalLength = input.available();
    // Skip to the batch offset.
    input.skip(offset);
    Integer read = IOUtils.read(input, buffer, la.length, batchSize);
    // Split at the last newline: a record cut at the batch boundary is stashed
    // and re-joined when processing the next batch.
    byte checkpoint = 10; // '\n'
    int lastN = ArrayUtils.lastIndexOf(buffer, checkpoint);
    byte[] conArray = ArrayUtils.subarray(buffer, 0, lastN);
    byte[] lastByte = ArrayUtils.subarray(buffer, lastN, buffer.length);
    // Fix: decode explicitly as UTF-8. The CSV is written as UTF-8 elsewhere
    // in this pipeline; the original used the platform default charset, which
    // corrupts non-ASCII content on non-UTF-8 JVMs.
    String content = new String(conArray, StandardCharsets.UTF_8);
    // Fix: advance by the bytes actually consumed from the FILE this batch.
    // The original used offset + lastByte.length + conArray.length, which
    // equals offset + la.length + read — double-counting the carried-over
    // bytes and silently skipping la.length bytes of the file on every batch
    // after the first.
    int nextOffset = offset + read;
    boolean hasNextBatch = read + offset < totalLength;
    rv.put("hasNextBatch", hasNextBatch);
    rv.put("content", content);
    rv.put("lastByte", lastByte);
    rv.put("nextOffset", nextOffset);
    return rv;
}
总结
综上,在数据源接入处理的读取数据源、写入 csv 文件、处理数据、写入数据库各个环节,采用分批处理的方式进行内存管理,有效避免内存溢出问题。