Java 多线程在大文件导出、批量插入与多表查询场景实战

406 阅读7分钟

Java 多线程在大文件导出、批量插入与多表查询场景中的深度应用

在 Java 开发领域,多线程技术是提升应用性能和处理效率的重要手段。尤其是在面对大文件导出、批量插入数据以及多表查询大量数据等复杂场景时,合理运用多线程能够有效利用系统资源,缩短任务执行时间。接下来,我们将结合这三个典型场景,详细分析建表方案,并给出包含详细注释的核心代码。

一、大文件导出场景

1.1 建表设计

假设我们要导出的是用户数据,包括用户的基本信息和订单信息。我们可以设计两张表:user_info表用于存储用户基本信息,order_info表用于存储用户订单信息。

user_info表结构设计如下:

字段名数据类型说明
user_idbigint用户 ID,主键
user_namevarchar(50)用户名
ageint年龄
gendertinyint性别(0:男,1:女)
emailvarchar(100)邮箱

order_info表结构设计如下:

字段名数据类型说明
order_idbigint订单 ID,主键
user_idbigint用户 ID,外键,关联user_info表的user_id
order_amountdecimal(10, 2)订单金额
order_datedatetime订单日期

在 MySQL 中创建这两张表的 SQL 语句如下:

CREATE TABLE user_info (
    user_id bigint PRIMARY KEY,
    user_name varchar(50) NOT NULL,
    age int,
    gender tinyint,
    email varchar(100)
);
CREATE TABLE order_info (
    order_id bigint PRIMARY KEY,
    user_id bigint,
    order_amount decimal(10, 2),
    order_date datetime,
    FOREIGN KEY (user_id) REFERENCES user_info(user_id)
);

1.2 核心代码实现

我们使用 Java 的多线程来实现大文件导出,将数据分块写入文件,提高导出效率。以下是核心代码:

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
public class BigFileExport {
    private static final String JDBC_URL = "jdbc:mysql://localhost:3306/your_database?useSSL=false&serverTimezone=UTC";
    private static final String JDBC_USER = "your_username";
    private static final String JDBC_PASSWORD = "your_password";
    // 每个线程处理的数据量
    private static final int DATA_PER_THREAD = 1000;
    // 导出文件路径
    private static final String EXPORT_FILE_PATH = "big_file_export.csv";
    // 定义Callable接口实现类,用于每个线程执行的任务
    static class ExportTask implements Callable<Void> {
        private int startIndex;
        public ExportTask(int startIndex) {
            this.startIndex = startIndex;
        }
        @Override
        public Void call() throws Exception {
            try (Connection connection = DriverManager.getConnection(JDBC_URL, JDBC_USER, JDBC_PASSWORD);
                 PreparedStatement statement = connection.prepareStatement("SELECT ui.user_name, ui.age, ui.gender, ui.email, oi.order_amount, oi.order_date " +
                         "FROM user_info ui " +
                         "JOIN order_info oi ON ui.user_id = oi.user_id " +
                         "LIMIT?,?");
                 BufferedWriter writer = new BufferedWriter(new FileWriter(EXPORT_FILE_PATH, true))) {
                statement.setInt(1, startIndex);
                statement.setInt(2, DATA_PER_THREAD);
                ResultSet resultSet = statement.executeQuery();
                while (resultSet.next()) {
                    // 拼接数据行
                    String dataRow = resultSet.getString("user_name") + "," +
                            resultSet.getInt("age") + "," +
                            resultSet.getInt("gender") + "," +
                            resultSet.getString("email") + "," +
                            resultSet.getBigDecimal("order_amount") + "," +
                            resultSet.getTimestamp("order_date") + "\n";
                    writer.write(dataRow);
                }
            } catch (SQLException | IOException e) {
                e.printStackTrace();
            }
            return null;
        }
    }
    public static void main(String[] args) {
        try (Connection connection = DriverManager.getConnection(JDBC_URL, JDBC_USER, JDBC_PASSWORD);
             PreparedStatement countStatement = connection.prepareStatement("SELECT COUNT(*) FROM user_info");
             ResultSet countResultSet = countStatement.executeQuery()) {
            countResultSet.next();
            int totalCount = countResultSet.getInt(1);
            int threadCount = (int) Math.ceil((double) totalCount / DATA_PER_THREAD);
            ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
            List<Future<Void>> futures = new ArrayList<>();
            for (int i = 0; i < threadCount; i++) {
                int startIndex = i * DATA_PER_THREAD;
                futures.add(executorService.submit(new ExportTask(startIndex)));
            }
            executorService.shutdown();
            for (Future<Void> future : futures) {
                try {
                    future.get();
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
            System.out.println("大文件导出完成!");
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }
}

代码注释说明:

  1. 常量定义:定义了数据库连接信息、每个线程处理的数据量以及导出文件路径。
  1. ExportTask :实现Callable接口,定义每个线程执行的任务。在call方法中,通过数据库查询获取对应的数据块,然后将数据写入导出文件。
  1. main 方法
    • 首先获取数据总条数,计算需要的线程数量。
    • 创建线程池,并提交多个ExportTask任务。
    • 等待所有任务执行完成后,关闭线程池并输出导出完成信息。

二、批量插入数据场景

2.1 建表设计

假设我们要批量插入的是商品信息,创建product_info表用于存储商品数据。

product_info表结构设计如下:

字段名数据类型说明
product_idbigint商品 ID,主键
product_namevarchar(100)商品名称
pricedecimal(10, 2)商品价格
stockint商品库存
category_idbigint商品分类 ID

在 MySQL 中创建该表的 SQL 语句如下:

CREATE TABLE product_info (
    product_id bigint PRIMARY KEY,
    product_name varchar(100) NOT NULL,
    price decimal(10, 2),
    stock int,
    category_id bigint
);

2.2 核心代码实现

使用多线程批量插入数据,可以加快数据插入速度。以下是核心代码:

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
public class BatchInsertData {
    private static final String JDBC_URL = "jdbc:mysql://localhost:3306/your_database?useSSL=false&serverTimezone=UTC";
    private static final String JDBC_USER = "your_username";
    private static final String JDBC_PASSWORD = "your_password";
    // 每个线程插入的数据量
    private static final int DATA_PER_THREAD = 1000;
    // 模拟生成要插入的数据
    private static List<Product> generateData() {
        List<Product> productList = new ArrayList<>();
        for (long i = 1; i <= 10000; i++) {
            productList.add(new Product(i, "Product_" + i, 9.99, 100, i % 10));
        }
        return productList;
    }
    // 定义产品类
    static class Product {
        long product_id;
        String product_name;
        double price;
        int stock;
        long category_id;
        public Product(long product_id, String product_name, double price, int stock, long category_id) {
            this.product_id = product_id;
            this.product_name = product_name;
            this.price = price;
            this.stock = stock;
            this.category_id = category_id;
        }
    }
    // 定义Callable接口实现类,用于每个线程执行的任务
    static class InsertTask implements Callable<Void> {
        private List<Product> dataList;
        private int startIndex;
        public InsertTask(List<Product> dataList, int startIndex) {
            this.dataList = dataList;
            this.startIndex = startIndex;
        }
        @Override
        public Void call() throws Exception {
            try (Connection connection = DriverManager.getConnection(JDBC_URL, JDBC_USER, JDBC_PASSWORD);
                 PreparedStatement statement = connection.prepareStatement("INSERT INTO product_info (product_id, product_name, price, stock, category_id) VALUES (?,?,?,?,?)")) {
                for (int i = startIndex; i < startIndex + DATA_PER_THREAD && i < dataList.size(); i++) {
                    Product product = dataList.get(i);
                    statement.setLong(1, product.product_id);
                    statement.setString(2, product.product_name);
                    statement.setDouble(3, product.price);
                    statement.setInt(4, product.stock);
                    statement.setLong(5, product.category_id);
                    statement.addBatch();
                }
                statement.executeBatch();
            } catch (SQLException e) {
                e.printStackTrace();
            }
            return null;
        }
    }
    public static void main(String[] args) {
        List<Product> productList = generateData();
        int dataSize = productList.size();
        int threadCount = (int) Math.ceil((double) dataSize / DATA_PER_THREAD);
        ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
        List<Future<Void>> futures = new ArrayList<>();
        for (int i = 0; i < threadCount; i++) {
            int startIndex = i * DATA_PER_THREAD;
            futures.add(executorService.submit(new InsertTask(productList, startIndex)));
        }
        executorService.shutdown();
        for (Future<Void> future : futures) {
            try {
                future.get();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        System.out.println("批量插入数据完成!");
    }
}

代码注释说明:

  1. 常量定义:定义了数据库连接信息和每个线程插入的数据量。
  1. generateData 方法:模拟生成要插入的商品数据。
  1. Product :定义商品实体类,用于存储商品信息。
  1. InsertTask :实现Callable接口,定义每个线程执行的插入任务。在call方法中,从数据列表中获取对应的数据块,通过PreparedStatement的addBatch和executeBatch方法批量插入数据。
  1. main 方法
    • 首先生成数据,计算线程数量。
    • 创建线程池,并提交多个InsertTask任务。
    • 等待所有任务执行完成后,关闭线程池并输出插入完成信息。

三、多表查询大量数据场景

3.1 建表设计

我们继续使用前面大文件导出场景中的user_info表和order_info表,通过多表查询获取用户及其订单的相关信息。

3.2 核心代码实现

使用多线程进行多表查询,可以并行处理不同的数据块,提高查询效率。以下是核心代码:

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
public class MultiTableQuery {
    private static final String JDBC_URL = "jdbc:mysql://localhost:3306/your_database?useSSL=false&serverTimezone=UTC";
    private static final String JDBC_USER = "your_username";
    private static final String JDBC_PASSWORD = "your_password";
    // 每个线程处理的数据量
    private static final int DATA_PER_THREAD = 1000;
    // 定义Callable接口实现类,用于每个线程执行的任务
    static class QueryTask implements Callable<List<UserOrder>> {
        private int startIndex;
        public QueryTask(int startIndex) {
            this.startIndex = startIndex;
        }
        @Override
        public List<UserOrder> call() throws Exception {
            List<UserOrder> userOrderList = new ArrayList<>();
            try (Connection connection = DriverManager.getConnection(JDBC_URL, JDBC_USER, JDBC_PASSWORD);
                 PreparedStatement statement = connection.prepareStatement("SELECT ui.user_name, ui.age, ui.gender, ui.email, oi.order_amount, oi.order_date " +
                         "FROM user_info ui " +
                         "JOIN order_info oi ON ui.user_id = oi.user_id " +
                         "LIMIT?,?")) {
                statement.setInt(1, startIndex);
                statement.setInt(2, DATA_PER_THREAD);
                ResultSet resultSet = statement.executeQuery();
                while (resultSet.next()) {
                    UserOrder userOrder = new UserOrder();
                    userOrder.userName = resultSet.getString("user_name");
                    userOrder.age = resultSet.getInt("age");
                    userOrder.gender = resultSet.getInt("gender");
                    userOrder.email = resultSet.getString("email");
                    userOrder.orderAmount = resultSet.getBigDecimal("order_amount");
                    userOrder.orderDate = resultSet.getTimestamp("order_date");
                    userOrderList.add(userOrder);
                }
            } catch (SQLException e) {
                e.printStackTrace();
            }
            return userOrderList;
        }
    }
    // 定义用户订单类
    static class UserOrder {
        String userName;
        int age;
        int gender;
        String email;
        java.math.BigDecimal orderAmount;
        java.sql.Timestamp orderDate;
    }
    public static void main(String[] args) {
        try (Connection connection = DriverManager.getConnection(JDBC_URL, JDBC_USER, JDBC_PASSWORD);
             PreparedStatement countStatement = connection.prepareStatement("SELECT COUNT(*) FROM user_info");
             ResultSet countResultSet = countStatement.executeQuery()) {
            countResultSet.next();
            int totalCount = countResultSet.getInt(1);
            int threadCount = (int) Math.ceil((double) totalCount / DATA_PER_THREAD);
            ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
            List<Future<List<UserOrder>>> futures = new ArrayList<>();
            for (int i = 0; i < threadCount; i++) {
                int startIndex = i * DATA_PER_THREAD;
                futures.add(executorService.submit(new QueryTask(startIndex)));
            }
            executorService.shutdown();
            List<UserOrder> allUserOrders = new ArrayList<>();
            for (Future<List<UserOrder>> future : futures) {
                try {
                    allUserOrders.addAll(future.get());
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
            System.out.println("多表查询大量数据完成,共获取 " + allUserOrders.size() + " 条数据!");
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }
}

代码注释说明:

  1. 常量定义:定义了数据库连接信息和每个线程处理的数据量。
  1. QueryTask :实现Callable接口,定义每个线程执行的查询任务。在call方法中,通过数据库查询获取对应的数据块,并将查询结果封装成UserOrder对象,添加到列表中返回。
  1. UserOrder :定义用户订单实体类,用于存储查询结果中的用户和订单信息。
  1. main 方法
    • 首先获取数据总条数,计算需要的线程数量。
    • 创建线程池,并提交多个QueryTask任务。
    • 等待所有任务执行完成后,将各个线程的查询结果合并,关闭线程池并输出