Java 多线程在大文件导出、批量插入与多表查询场景中的深度应用
在 Java 开发领域,多线程技术是提升应用性能和处理效率的重要手段。尤其是在面对大文件导出、批量插入数据以及多表查询大量数据等复杂场景时,合理运用多线程能够有效利用系统资源,缩短任务执行时间。接下来,我们将结合这三个典型场景,详细分析建表方案,并给出包含详细注释的核心代码。
一、大文件导出场景
1.1 建表设计
假设我们要导出的是用户数据,包括用户的基本信息和订单信息。我们可以设计两张表:user_info表用于存储用户基本信息,order_info表用于存储用户订单信息。
user_info表结构设计如下:
| 字段名 | 数据类型 | 说明 |
|---|---|---|
| user_id | bigint | 用户 ID,主键 |
| user_name | varchar(50) | 用户名 |
| age | int | 年龄 |
| gender | tinyint | 性别(0:男,1:女) |
| varchar(100) | 邮箱 |
order_info表结构设计如下:
| 字段名 | 数据类型 | 说明 |
|---|---|---|
| order_id | bigint | 订单 ID,主键 |
| user_id | bigint | 用户 ID,外键,关联user_info表的user_id |
| order_amount | decimal(10, 2) | 订单金额 |
| order_date | datetime | 订单日期 |
在 MySQL 中创建这两张表的 SQL 语句如下:
CREATE TABLE user_info (
user_id bigint PRIMARY KEY,
user_name varchar(50) NOT NULL,
age int,
gender tinyint,
email varchar(100)
);
CREATE TABLE order_info (
order_id bigint PRIMARY KEY,
user_id bigint,
order_amount decimal(10, 2),
order_date datetime,
FOREIGN KEY (user_id) REFERENCES user_info(user_id)
);
1.2 核心代码实现
我们使用 Java 的多线程来实现大文件导出,将数据分块写入文件,提高导出效率。以下是核心代码:
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
public class BigFileExport {
private static final String JDBC_URL = "jdbc:mysql://localhost:3306/your_database?useSSL=false&serverTimezone=UTC";
private static final String JDBC_USER = "your_username";
private static final String JDBC_PASSWORD = "your_password";
// 每个线程处理的数据量
private static final int DATA_PER_THREAD = 1000;
// 导出文件路径
private static final String EXPORT_FILE_PATH = "big_file_export.csv";
// 定义Callable接口实现类,用于每个线程执行的任务
static class ExportTask implements Callable<Void> {
private int startIndex;
public ExportTask(int startIndex) {
this.startIndex = startIndex;
}
@Override
public Void call() throws Exception {
try (Connection connection = DriverManager.getConnection(JDBC_URL, JDBC_USER, JDBC_PASSWORD);
PreparedStatement statement = connection.prepareStatement("SELECT ui.user_name, ui.age, ui.gender, ui.email, oi.order_amount, oi.order_date " +
"FROM user_info ui " +
"JOIN order_info oi ON ui.user_id = oi.user_id " +
"LIMIT?,?");
BufferedWriter writer = new BufferedWriter(new FileWriter(EXPORT_FILE_PATH, true))) {
statement.setInt(1, startIndex);
statement.setInt(2, DATA_PER_THREAD);
ResultSet resultSet = statement.executeQuery();
while (resultSet.next()) {
// 拼接数据行
String dataRow = resultSet.getString("user_name") + "," +
resultSet.getInt("age") + "," +
resultSet.getInt("gender") + "," +
resultSet.getString("email") + "," +
resultSet.getBigDecimal("order_amount") + "," +
resultSet.getTimestamp("order_date") + "\n";
writer.write(dataRow);
}
} catch (SQLException | IOException e) {
e.printStackTrace();
}
return null;
}
}
public static void main(String[] args) {
try (Connection connection = DriverManager.getConnection(JDBC_URL, JDBC_USER, JDBC_PASSWORD);
PreparedStatement countStatement = connection.prepareStatement("SELECT COUNT(*) FROM user_info");
ResultSet countResultSet = countStatement.executeQuery()) {
countResultSet.next();
int totalCount = countResultSet.getInt(1);
int threadCount = (int) Math.ceil((double) totalCount / DATA_PER_THREAD);
ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
List<Future<Void>> futures = new ArrayList<>();
for (int i = 0; i < threadCount; i++) {
int startIndex = i * DATA_PER_THREAD;
futures.add(executorService.submit(new ExportTask(startIndex)));
}
executorService.shutdown();
for (Future<Void> future : futures) {
try {
future.get();
} catch (Exception e) {
e.printStackTrace();
}
}
System.out.println("大文件导出完成!");
} catch (SQLException e) {
e.printStackTrace();
}
}
}
代码注释说明:
- 常量定义:定义了数据库连接信息、每个线程处理的数据量以及导出文件路径。
- ExportTask 类:实现Callable接口,定义每个线程执行的任务。在call方法中,通过数据库查询获取对应的数据块,然后将数据写入导出文件。
- main 方法:
-
- 首先获取数据总条数,计算需要的线程数量。
-
- 创建线程池,并提交多个ExportTask任务。
-
- 等待所有任务执行完成后,关闭线程池并输出导出完成信息。
二、批量插入数据场景
2.1 建表设计
假设我们要批量插入的是商品信息,创建product_info表用于存储商品数据。
product_info表结构设计如下:
| 字段名 | 数据类型 | 说明 |
|---|---|---|
| product_id | bigint | 商品 ID,主键 |
| product_name | varchar(100) | 商品名称 |
| price | decimal(10, 2) | 商品价格 |
| stock | int | 商品库存 |
| category_id | bigint | 商品分类 ID |
在 MySQL 中创建该表的 SQL 语句如下:
CREATE TABLE product_info (
product_id bigint PRIMARY KEY,
product_name varchar(100) NOT NULL,
price decimal(10, 2),
stock int,
category_id bigint
);
2.2 核心代码实现
使用多线程批量插入数据,可以加快数据插入速度。以下是核心代码:
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
public class BatchInsertData {
private static final String JDBC_URL = "jdbc:mysql://localhost:3306/your_database?useSSL=false&serverTimezone=UTC";
private static final String JDBC_USER = "your_username";
private static final String JDBC_PASSWORD = "your_password";
// 每个线程插入的数据量
private static final int DATA_PER_THREAD = 1000;
// 模拟生成要插入的数据
private static List<Product> generateData() {
List<Product> productList = new ArrayList<>();
for (long i = 1; i <= 10000; i++) {
productList.add(new Product(i, "Product_" + i, 9.99, 100, i % 10));
}
return productList;
}
// 定义产品类
static class Product {
long product_id;
String product_name;
double price;
int stock;
long category_id;
public Product(long product_id, String product_name, double price, int stock, long category_id) {
this.product_id = product_id;
this.product_name = product_name;
this.price = price;
this.stock = stock;
this.category_id = category_id;
}
}
// 定义Callable接口实现类,用于每个线程执行的任务
static class InsertTask implements Callable<Void> {
private List<Product> dataList;
private int startIndex;
public InsertTask(List<Product> dataList, int startIndex) {
this.dataList = dataList;
this.startIndex = startIndex;
}
@Override
public Void call() throws Exception {
try (Connection connection = DriverManager.getConnection(JDBC_URL, JDBC_USER, JDBC_PASSWORD);
PreparedStatement statement = connection.prepareStatement("INSERT INTO product_info (product_id, product_name, price, stock, category_id) VALUES (?,?,?,?,?)")) {
for (int i = startIndex; i < startIndex + DATA_PER_THREAD && i < dataList.size(); i++) {
Product product = dataList.get(i);
statement.setLong(1, product.product_id);
statement.setString(2, product.product_name);
statement.setDouble(3, product.price);
statement.setInt(4, product.stock);
statement.setLong(5, product.category_id);
statement.addBatch();
}
statement.executeBatch();
} catch (SQLException e) {
e.printStackTrace();
}
return null;
}
}
public static void main(String[] args) {
List<Product> productList = generateData();
int dataSize = productList.size();
int threadCount = (int) Math.ceil((double) dataSize / DATA_PER_THREAD);
ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
List<Future<Void>> futures = new ArrayList<>();
for (int i = 0; i < threadCount; i++) {
int startIndex = i * DATA_PER_THREAD;
futures.add(executorService.submit(new InsertTask(productList, startIndex)));
}
executorService.shutdown();
for (Future<Void> future : futures) {
try {
future.get();
} catch (Exception e) {
e.printStackTrace();
}
}
System.out.println("批量插入数据完成!");
}
}
代码注释说明:
- 常量定义:定义了数据库连接信息和每个线程插入的数据量。
- generateData 方法:模拟生成要插入的商品数据。
- Product 类:定义商品实体类,用于存储商品信息。
- InsertTask 类:实现Callable接口,定义每个线程执行的插入任务。在call方法中,从数据列表中获取对应的数据块,通过PreparedStatement的addBatch和executeBatch方法批量插入数据。
- main 方法:
-
- 首先生成数据,计算线程数量。
-
- 创建线程池,并提交多个InsertTask任务。
-
- 等待所有任务执行完成后,关闭线程池并输出插入完成信息。
三、多表查询大量数据场景
3.1 建表设计
我们继续使用前面大文件导出场景中的user_info表和order_info表,通过多表查询获取用户及其订单的相关信息。
3.2 核心代码实现
使用多线程进行多表查询,可以并行处理不同的数据块,提高查询效率。以下是核心代码:
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
public class MultiTableQuery {
private static final String JDBC_URL = "jdbc:mysql://localhost:3306/your_database?useSSL=false&serverTimezone=UTC";
private static final String JDBC_USER = "your_username";
private static final String JDBC_PASSWORD = "your_password";
// 每个线程处理的数据量
private static final int DATA_PER_THREAD = 1000;
// 定义Callable接口实现类,用于每个线程执行的任务
static class QueryTask implements Callable<List<UserOrder>> {
private int startIndex;
public QueryTask(int startIndex) {
this.startIndex = startIndex;
}
@Override
public List<UserOrder> call() throws Exception {
List<UserOrder> userOrderList = new ArrayList<>();
try (Connection connection = DriverManager.getConnection(JDBC_URL, JDBC_USER, JDBC_PASSWORD);
PreparedStatement statement = connection.prepareStatement("SELECT ui.user_name, ui.age, ui.gender, ui.email, oi.order_amount, oi.order_date " +
"FROM user_info ui " +
"JOIN order_info oi ON ui.user_id = oi.user_id " +
"LIMIT?,?")) {
statement.setInt(1, startIndex);
statement.setInt(2, DATA_PER_THREAD);
ResultSet resultSet = statement.executeQuery();
while (resultSet.next()) {
UserOrder userOrder = new UserOrder();
userOrder.userName = resultSet.getString("user_name");
userOrder.age = resultSet.getInt("age");
userOrder.gender = resultSet.getInt("gender");
userOrder.email = resultSet.getString("email");
userOrder.orderAmount = resultSet.getBigDecimal("order_amount");
userOrder.orderDate = resultSet.getTimestamp("order_date");
userOrderList.add(userOrder);
}
} catch (SQLException e) {
e.printStackTrace();
}
return userOrderList;
}
}
// 定义用户订单类
static class UserOrder {
String userName;
int age;
int gender;
String email;
java.math.BigDecimal orderAmount;
java.sql.Timestamp orderDate;
}
public static void main(String[] args) {
try (Connection connection = DriverManager.getConnection(JDBC_URL, JDBC_USER, JDBC_PASSWORD);
PreparedStatement countStatement = connection.prepareStatement("SELECT COUNT(*) FROM user_info");
ResultSet countResultSet = countStatement.executeQuery()) {
countResultSet.next();
int totalCount = countResultSet.getInt(1);
int threadCount = (int) Math.ceil((double) totalCount / DATA_PER_THREAD);
ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
List<Future<List<UserOrder>>> futures = new ArrayList<>();
for (int i = 0; i < threadCount; i++) {
int startIndex = i * DATA_PER_THREAD;
futures.add(executorService.submit(new QueryTask(startIndex)));
}
executorService.shutdown();
List<UserOrder> allUserOrders = new ArrayList<>();
for (Future<List<UserOrder>> future : futures) {
try {
allUserOrders.addAll(future.get());
} catch (Exception e) {
e.printStackTrace();
}
}
System.out.println("多表查询大量数据完成,共获取 " + allUserOrders.size() + " 条数据!");
} catch (SQLException e) {
e.printStackTrace();
}
}
}
代码注释说明:
- 常量定义:定义了数据库连接信息和每个线程处理的数据量。
- QueryTask 类:实现Callable接口,定义每个线程执行的查询任务。在call方法中,通过数据库查询获取对应的数据块,并将查询结果封装成UserOrder对象,添加到列表中返回。
- UserOrder 类:定义用户订单实体类,用于存储查询结果中的用户和订单信息。
- main 方法:
-
- 首先获取数据总条数,计算需要的线程数量。
-
- 创建线程池,并提交多个QueryTask任务。
-
- 等待所有任务执行完成后,将各个线程的查询结果合并,关闭线程池并输出