Combining Elasticsearch with Machine Learning: An Efficient Model Inference Solution (Part 1)

In the era of big data, the fusion of search engines and machine learning has become an important direction in data processing. Elasticsearch is not only a high-performance search and analytics engine; it also provides a complete machine learning inference framework that lets us deploy and run ML models efficiently in a distributed environment.

1. Overview of Elasticsearch Machine Learning Features

Elasticsearch's machine learning capabilities have evolved with each release, from early anomaly detection to today's mature model inference framework, providing strong support for data analysis.

Core capabilities include:

  • Anomaly detection: automatically identifies anomalous patterns and outliers in data
  • Forecasting: time-series prediction and trend analysis based on historical data
  • Model inference: deploy externally trained machine learning models in ES and run real-time or batch inference

The main advantage of the ES machine learning framework is that it leverages the distributed architecture for highly available, highly scalable model deployment, with no separate ML serving infrastructure to maintain.

Feature Comparison by Version

| Feature | ES 7.x | ES 8.0-8.3 | ES 8.4+ | ES 8.8+ |
| --- | --- | --- | --- | --- |
| Anomaly detection | ✓ | ✓ | ✓ | ✓ |
| Time-series forecasting | Limited | ✓ | ✓ | ✓ |
| ONNX model support | - | Partial | Full | Full |
| Native PyTorch models | - | - | - | Supported |
| Distributed inference | - | Limited | Full | Enhanced |
| Pretrained NLP model library | - | Limited | Rich | Comprehensive |
| Inference memory limits | - | Basic control | Enhanced control | Fine-grained control |
| Model deployment isolation | - | - | Limited | Enhanced |

2. How Inference Works and Supported Model Types

[Figure: How inference works]

2.1 Supported Model Types

Elasticsearch currently supports the following types of machine learning models:

  1. PyTorch models (via ONNX conversion, or natively supported in 8.8+)
  2. scikit-learn models (classifiers, regressors, etc.)
  3. XGBoost models (classification and regression)
  4. LightGBM models (classification and regression)
  5. Pretrained NLP models (BERT, sentence-transformers, etc.)

2.2 ONNX Operator Support

For ONNX models, the operator set supported by Elasticsearch includes:

  • Basic operations: Add, Sub, Mul, Div, Pow
  • Neural network layers: Conv, MaxPool, BatchNormalization, Dropout
  • Activation functions: Relu, Sigmoid, Tanh, LeakyRelu
  • Data operations: Reshape, Transpose, Concat, Split
  • Sequence processing: LSTM, GRU (fully supported since 8.4)

Before importing a model, confirm that the operators it uses are on the ES support list. You can verify model compatibility with ONNX Runtime's inspection API:

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;

public class ModelCompatibilityChecker {
    public static boolean isModelCompatible(String modelPath) {
        try {
            OrtEnvironment env = OrtEnvironment.getEnvironment();
            OrtSession.SessionOptions options = new OrtSession.SessionOptions();

            // Load the model to validate it
            OrtSession session = env.createSession(modelPath, options);

            // Inspect the model's inputs and outputs
            session.getInputNames().forEach(name ->
                System.out.println("Input: " + name));
            session.getOutputNames().forEach(name ->
                System.out.println("Output: " + name));

            session.close();
            env.close();
            return true;
        } catch (Exception e) {
            System.err.println("Model compatibility check failed: " + e.getMessage());
            return false;
        }
    }
}

2.3 Inference Processing Flow

Inference runs through Elasticsearch's ingest pipelines:

  1. When a document passes through the pipeline, the inference processor is triggered
  2. The processor extracts feature data from the document
  3. The data is fed into the model for inference
  4. The inference result is stored in the target field
  5. The processed document then continues through the indexing flow (see the simulation sketch below)
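
To make the flow concrete, here is a minimal sketch (the pipeline ID my-inference-pipeline is a placeholder) that pushes a sample document through the _simulate endpoint, which runs the pipeline's processors without actually indexing anything:

import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.json.JsonData;

import java.util.Map;

public class PipelineFlowDemo {
    /**
     * Simulates inference for one document so the processor output
     * can be inspected before wiring the pipeline into indexing.
     */
    public static void simulateInference(ElasticsearchClient client) throws Exception {
        String pipelineId = "my-inference-pipeline"; // placeholder pipeline ID

        var response = client.ingest().simulate(s -> s
            .id(pipelineId)
            .docs(d -> d.source(JsonData.of(Map.of("text", "sample input text"))))
        );

        // Each simulated doc carries the pipeline output, including the target field
        response.docs().forEach(doc ->
            System.out.println("Processed doc: " + doc.doc().source()));
    }
}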

3. Implementation Steps and Environment Preparation

[Figure: Implementation steps]

3.1 Environment Preparation

  1. Elasticsearch 8.x (8.4+ recommended)
  2. Java 11+
  3. A Java client matching your ES version

Maven dependency configuration:

<dependencies>
    <!-- Elasticsearch Java client -->
    <dependency>
        <groupId>co.elastic.clients</groupId>
        <artifactId>elasticsearch-java</artifactId>
        <version>8.8.0</version>
    </dependency>

    <!-- JSON processing -->
    <dependency>
        <groupId>com.fasterxml.jackson.core</groupId>
        <artifactId>jackson-databind</artifactId>
        <version>2.14.2</version>
    </dependency>

    <!-- Logging framework -->
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-api</artifactId>
        <version>1.7.36</version>
    </dependency>
    <dependency>
        <groupId>ch.qos.logback</groupId>
        <artifactId>logback-classic</artifactId>
        <version>1.2.11</version>
    </dependency>

    <!-- ONNX Runtime (for model validation and optimization) -->
    <dependency>
        <groupId>com.microsoft.onnxruntime</groupId>
        <artifactId>onnxruntime</artifactId>
        <version>1.13.1</version>
    </dependency>
</dependencies>

3.2 Logging Configuration Example

To keep logging consistent in production, add the following logback.xml configuration:

<?xml version="1.0" encoding="UTF-8"?>
<configuration>
    <appender name="CONSOLE" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
        </encoder>
    </appender>

    <appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender">
        <file>logs/es-ml-app.log</file>
        <rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
            <fileNamePattern>logs/es-ml-app.%d{yyyy-MM-dd}.log</fileNamePattern>
            <maxHistory>30</maxHistory>
        </rollingPolicy>
        <encoder>
            <pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
        </encoder>
    </appender>

    <logger name="co.elastic" level="INFO"/>
    <logger name="org.elasticsearch" level="WARN"/>
    <logger name="com.yourcompany.esml" level="DEBUG"/>

    <root level="INFO">
        <appender-ref ref="CONSOLE"/>
        <appender-ref ref="FILE"/>
    </root>
</configuration>

3.3 Elasticsearch Configuration Parameters

Running ML inference efficiently requires tuning the following key parameters:

# Example elasticsearch.yml settings
xpack.ml.max_model_memory_limit: 1gb      # Maximum memory for a single model
xpack.ml.max_machine_memory_percent: 30   # Maximum percentage of machine memory ML may use
xpack.ml.max_inference_processors: 4      # Maximum inference processors per node
thread_pool.ingest.queue_size: 200        # Inference request queue size
thread_pool.ingest.size: 8                # Inference thread pool size
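
Settings applied through elasticsearch.yml can be confirmed at runtime. A small sketch that prints the effective cluster settings via the get-settings API (with include_defaults), so the xpack.ml.* values above can be checked against what the cluster actually uses:

import co.elastic.clients.elasticsearch.ElasticsearchClient;

public class MLSettingsCheck {
    /**
     * Prints persistent and transient cluster settings; dynamic settings
     * override elasticsearch.yml, so this shows the values that win.
     */
    public static void printMlSettings(ElasticsearchClient client) throws Exception {
        var response = client.cluster()
                .getSettings(s -> s.includeDefaults(true).flatSettings(true));

        System.out.println("persistent: " + response.persistent());
        System.out.println("transient:  " + response.transient_());
    }
}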

3.4 Index Mapping Definition

Create an appropriate index mapping for inference results so that field types are correct and queries stay efficient:

import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.indices.CreateIndexResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;

public class IndexManager {
    private static final Logger logger = LoggerFactory.getLogger(IndexManager.class);

    /**
     * Creates the index template for inference results.
     */
    public void createInferenceIndexTemplate(ElasticsearchClient client) throws IOException {
        try {
            client.indices().putIndexTemplate(t -> t
                .name("ml-inference-template")
                .indexPatterns("inference-*", "ml-results-*")
                .template(it -> it
                    .settings(s -> s
                        .numberOfShards("3")
                        .numberOfReplicas("1")
                        .refreshInterval(r -> r.time("5s"))
                    )
                    .mappings(m -> m
                        // Inner lambda renamed: the outer "t" is still in scope here
                        .properties("text", p -> p.text(tp -> tp
                            .analyzer("standard")
                            .fields("keyword", f -> f.keyword(k -> k))
                        ))
                        .properties("prediction", p -> p.object(o -> o
                            .properties("score", s -> s.float_(f -> f))
                            .properties("label", l -> l.keyword(k -> k))
                        ))
                        .properties("timestamp", p -> p.date(d -> d))
                        .properties("processing_time_ms", p -> p.long_(l -> l))
                    )
                )
            );

            logger.info("Inference result index template created");
        } catch (Exception e) {
            logger.error("Failed to create index template: {}", e.getMessage(), e);
            throw e;
        }
    }

    /**
     * Creates a specific inference index.
     */
    public boolean createInferenceIndex(ElasticsearchClient client, String indexName) {
        try {
            CreateIndexResponse response = client.indices().create(c -> c
                .index(indexName)
                .aliases(indexName + "_alias", a -> a)
            );

            boolean acknowledged = response.acknowledged();
            logger.info("Index {} creation {}", indexName, acknowledged ? "succeeded" : "failed");
            return acknowledged;
        } catch (Exception e) {
            if (e.getMessage() != null && e.getMessage().contains("resource_already_exists_exception")) {
                logger.info("Index {} already exists", indexName);
                return true;
            }
            logger.error("Failed to create index: {}", e.getMessage(), e);
            return false;
        }
    }
}
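
A short usage sketch for the class above (ESClientUtil is defined later in section 4.2; the index name inference-demo is a placeholder matching the template pattern):

public class IndexSetupExample {
    public static void main(String[] args) throws Exception {
        var client = ESClientUtil.createClient();
        IndexManager manager = new IndexManager();

        // Create the template first so the new index inherits its mappings
        manager.createInferenceIndexTemplate(client);
        manager.createInferenceIndex(client, "inference-demo");
    }
}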

4. Code Implementation: Model Import and Inference

4.1 Exception Type Definitions

First, define a clear exception hierarchy:

/**
 * Base exception for ES machine learning operations
 */
public class ESMLException extends RuntimeException {
    public ESMLException(String message) {
        super(message);
    }

    public ESMLException(String message, Throwable cause) {
        super(message, cause);
    }
}

/**
 * Model operation exception
 */
public class ModelOperationException extends ESMLException {
    private final String modelId;
    private final int statusCode;

    public ModelOperationException(String message, String modelId, int statusCode, Throwable cause) {
        super(message, cause);
        this.modelId = modelId;
        this.statusCode = statusCode;
    }

    public ModelOperationException(String message, String modelId, int statusCode) {
        this(message, modelId, statusCode, null);
    }

    public String getModelId() {
        return modelId;
    }

    public int getStatusCode() {
        return statusCode;
    }
}

/**
 * Inference exception
 */
public class InferenceException extends ESMLException {
    private final String pipelineId;

    public InferenceException(String message, String pipelineId, Throwable cause) {
        super(message, cause);
        this.pipelineId = pipelineId;
    }

    public InferenceException(String message, String pipelineId) {
        this(message, pipelineId, null);
    }

    public String getPipelineId() {
        return pipelineId;
    }
}

/**
 * Index operation exception
 */
public class IndexOperationException extends ESMLException {
    private final String indexName;

    public IndexOperationException(String message, String indexName, Throwable cause) {
        super(message, cause);
        this.indexName = indexName;
    }

    public String getIndexName() {
        return indexName;
    }
}

4.2 ES Client Utility Class

import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.json.jackson.JacksonJsonpMapper;
import co.elastic.clients.transport.ElasticsearchTransport;
import co.elastic.clients.transport.rest_client.RestClientTransport;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.conn.ssl.TrustSelfSignedStrategy;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.apache.http.ssl.SSLContextBuilder;
import org.apache.http.ssl.SSLContexts;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.util.Objects;

public class ESClientUtil {
    private static final Logger logger = LoggerFactory.getLogger(ESClientUtil.class);

    // Connection settings read from system properties or environment variables
    private static final String ES_HOST = System.getProperty("es.host", "localhost");
    private static final int ES_PORT = Integer.parseInt(System.getProperty("es.port", "9200"));
    private static final String ES_PROTOCOL = System.getProperty("es.protocol", "https");
    private static final String ES_USERNAME = System.getProperty("es.username", "elastic");
    private static final String ES_PASSWORD = System.getProperty("es.password", "changeme");
    private static final String ES_CERT_PATH = System.getProperty("es.cert.path");

    // Connection pool settings
    private static final int MAX_CONN_TOTAL = 100;
    private static final int MAX_CONN_PER_ROUTE = 30;
    private static final int CONNECTION_TIMEOUT_MS = 5000;
    private static final int SOCKET_TIMEOUT_MS = 60000;

    // Retry settings
    private static final int MAX_RETRIES = 3;
    private static final long RETRY_BACKOFF_MS = 1000;

    /**
     * Creates an ES client with connection pooling.
     * @return a configured Elasticsearch client
     * @throws IOException if client creation fails
     */
    public static ElasticsearchClient createClient() throws IOException {
        Objects.requireNonNull(ES_HOST, "ES host must not be null");

        RestClientBuilder builder = RestClient.builder(
                new HttpHost(ES_HOST, ES_PORT, ES_PROTOCOL));

        // Configure the connection pool
        builder.setHttpClientConfigCallback(httpClientBuilder -> {
            httpClientBuilder.setMaxConnTotal(MAX_CONN_TOTAL);
            httpClientBuilder.setMaxConnPerRoute(MAX_CONN_PER_ROUTE);

            // Configure authentication
            if (ES_USERNAME != null && !ES_USERNAME.isEmpty()) {
                final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
                credentialsProvider.setCredentials(AuthScope.ANY,
                        new UsernamePasswordCredentials(ES_USERNAME, ES_PASSWORD));
                httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
            }

            // Configure the SSL certificate
            if (ES_CERT_PATH != null && !ES_CERT_PATH.isEmpty()) {
                try {
                    SSLContext sslContext = buildSSLContext();
                    httpClientBuilder.setSSLContext(sslContext);
                } catch (Exception e) {
                    logger.error("SSL证书配置失败", e);
                }
            }

            return configureHttpClient(httpClientBuilder);
        });

        // Configure request timeouts
        builder.setRequestConfigCallback(requestConfigBuilder ->
            requestConfigBuilder
                .setConnectTimeout(CONNECTION_TIMEOUT_MS)
                .setSocketTimeout(SOCKET_TIMEOUT_MS)
        );

        // Create the client
        try {
            ElasticsearchTransport transport = new RestClientTransport(
                    builder.build(), new JacksonJsonpMapper());
            return new ElasticsearchClient(transport);
        } catch (Exception e) {
            logger.error("创建ES客户端失败", e);
            throw new IOException("创建ES客户端失败: " + e.getMessage(), e);
        }
    }

    /**
     * Configures the HTTP client.
     */
    private static HttpAsyncClientBuilder configureHttpClient(HttpAsyncClientBuilder httpClientBuilder) {
        return httpClientBuilder
            .setDefaultRequestConfig(
                RequestConfig.custom()
                    .setConnectTimeout(CONNECTION_TIMEOUT_MS)
                    .setSocketTimeout(SOCKET_TIMEOUT_MS)
                    .build()
            );
    }

    /**
     * Builds the SSL context.
     */
    private static SSLContext buildSSLContext() throws Exception {
        Path certPath = Path.of(ES_CERT_PATH);

        if (!Files.exists(certPath)) {
            throw new IllegalArgumentException("证书文件不存在: " + certPath);
        }

        CertificateFactory factory = CertificateFactory.getInstance("X.509");
        Certificate trustedCa;

        try (InputStream is = Files.newInputStream(certPath)) {
            trustedCa = factory.generateCertificate(is);
        }

        KeyStore trustStore = KeyStore.getInstance("pkcs12");
        trustStore.load(null, null);
        trustStore.setCertificateEntry("ca", trustedCa);

        return SSLContexts.custom()
                .loadTrustMaterial(trustStore, new TrustSelfSignedStrategy())
                .build();
    }

    /**
     * Closes the ES transport layer.
     */
    public static void closeTransport(ElasticsearchTransport transport) {
        if (transport != null) {
            try {
                transport.close();
            } catch (IOException e) {
                logger.warn("关闭ES传输层失败", e);
            }
        }
    }

    /**
     * Operation that may throw IOException. java.util.function.Supplier cannot
     * throw checked exceptions, so a dedicated functional interface is needed
     * for the retry wrapper to catch IOException from the operation body.
     */
    @FunctionalInterface
    public interface IOSupplier<T> {
        T get() throws IOException;
    }

    /**
     * Executes an ES operation with retries.
     * @param operation the operation to execute
     * @return the operation result
     * @throws IOException if the operation ultimately fails
     */
    public static <T> T executeWithRetry(IOSupplier<T> operation) throws IOException {
        int attempts = 0;
        IOException lastException = null;

        while (attempts < MAX_RETRIES) {
            try {
                return operation.get();
            } catch (IOException e) {
                if (isRetryableException(e)) {
                    lastException = e;
                    attempts++;

                    if (attempts < MAX_RETRIES) {
                        long backoffTime = RETRY_BACKOFF_MS * attempts;
                        logger.warn("ES操作失败,将在{}ms后重试(尝试{}/{}): {}",
                                backoffTime, attempts, MAX_RETRIES, e.getMessage());
                        try {
                            Thread.sleep(backoffTime);
                        } catch (InterruptedException ie) {
                            Thread.currentThread().interrupt();
                            throw new IOException("重试等待被中断", ie);
                        }
                    }
                } else {
                    // Non-retryable exceptions are rethrown immediately
                    logger.error("Non-retryable ES exception", e);
                    throw e;
                }
            }
        }

        logger.error("ES操作在{}次尝试后仍然失败", MAX_RETRIES);
        throw lastException;
    }

    /**
     * Determines whether an exception is retryable.
     */
    private static boolean isRetryableException(IOException e) {
        // Network-related exceptions are usually retryable
        if (e instanceof java.net.SocketTimeoutException ||
            e instanceof java.net.ConnectException) {
            return true;
        }

        // Certain HTTP status codes are also retryable
        String message = e.getMessage();
        if (message != null && (
                message.contains("429") || // Too Many Requests
                message.contains("503") || // Service Unavailable
                message.contains("507")    // Insufficient Storage
            )) {
            return true;
        }

        return false;
    }

    /**
     * Checks the ES cluster health status.
     */
    public static boolean isClusterHealthy() {
        try (var client = createClient()) {
            var response = client.cluster().health();
            // jsonValue() yields the lowercase wire value ("green"/"yellow"/"red")
            String status = response.status().jsonValue();
            return "green".equals(status) || "yellow".equals(status);
        } catch (Exception e) {
            logger.error("检查集群健康状态失败: {}", e.getMessage(), e);
            return false;
        }
    }
}
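
A quick usage sketch of the retry helper; the health check is just an illustrative operation, since any call that throws IOException fits the IOSupplier signature:

import co.elastic.clients.elasticsearch.ElasticsearchClient;

public class RetryUsageExample {
    public static void main(String[] args) throws Exception {
        ElasticsearchClient client = ESClientUtil.createClient();

        // The lambda is retried up to MAX_RETRIES times with linear backoff
        var health = ESClientUtil.executeWithRetry(() -> client.cluster().health());

        System.out.println("Cluster status: " + health.status().jsonValue());
    }
}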

4.3 Model Upload Interface and Implementation

First, define the model upload interface, following the dependency inversion principle:

/**
 * Model uploader interface
 */
public interface ModelUploader {
    /**
     * Uploads a model to Elasticsearch.
     * @param modelId model ID
     * @param modelPath path to the model file
     * @param description model description
     * @param tags model tags
     * @throws ModelOperationException if the upload fails
     */
    void uploadModel(String modelId, Path modelPath, String description, String... tags);

    /**
     * Checks whether a model exists.
     * @param modelId model ID
     * @return whether the model exists
     */
    boolean modelExists(String modelId);

    /**
     * Creates a model alias.
     * @param modelId model ID
     * @param aliasName the alias
     */
    void createModelAlias(String modelId, String aliasName);
}

Then the implementation:

import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.ml.PutTrainedModelRequest;
import co.elastic.clients.elasticsearch.ml.TrainedModelConfig;
import co.elastic.clients.elasticsearch._types.ElasticsearchException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Base64;
import java.util.Objects;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

public class ModelUploadService implements ModelUploader {
    private static final Logger logger = LoggerFactory.getLogger(ModelUploadService.class);
    private static final int CHUNK_SIZE = 1024 * 1024;  // 1MB
    private static final int MAX_WAIT_SECONDS = 120;    // Maximum time to wait for the model to become ready

    // Lock for concurrency control
    private final ReadWriteLock modelUpdateLock = new ReentrantReadWriteLock();

    /**
     * Uploads an ONNX model to Elasticsearch.
     */
    @Override
    public void uploadModel(String modelId, Path modelPath, String description, String... tags) {
        // Validate arguments
        Objects.requireNonNull(modelId, "modelId must not be null");
        Objects.requireNonNull(modelPath, "modelPath must not be null");
        if (!Files.exists(modelPath)) {
            throw new IllegalArgumentException("Model file does not exist: " + modelPath);
        }

        // Acquire the write lock for exclusive access during model updates
        modelUpdateLock.writeLock().lock();
        try (var client = ESClientUtil.createClient()) {
            // Check the model file size
            long fileSize = Files.size(modelPath);
            logger.info("Starting upload of model {}, file size: {} bytes", modelId, fileSize);

            // Choose the upload strategy based on file size
            if (fileSize > 10 * 1024 * 1024) {  // Chunked upload for files over 10MB
                uploadLargeModel(client, modelId, modelPath, description, tags);
            } else {
                uploadSmallModel(client, modelId, modelPath, description, tags);
            }

            // Verify the upload status
            verifyModelStatus(client, modelId);

            logger.info("Model {} uploaded successfully", modelId);
        } catch (ElasticsearchException e) {
            logger.error("ES operation error: {}, status: {}", e.getMessage(), e.status(), e);
            throw new ModelOperationException("Failed to upload model: " + e.getMessage(), modelId, e.status(), e);
        } catch (IOException e) {
            logger.error("IO error: {}", e.getMessage(), e);
            throw new ModelOperationException("Failed to read or upload model", modelId, 500, e);
        } finally {
            modelUpdateLock.writeLock().unlock();
        }
    }

    /**
     * Uploads a small model in a single request.
     */
    private void uploadSmallModel(
            ElasticsearchClient client,
            String modelId,
            Path modelPath,
            String description,
            String[] tags) throws IOException {

        // Read and Base64-encode the model file
        String modelBase64;
        try (InputStream is = Files.newInputStream(modelPath)) {
            byte[] modelBytes = is.readAllBytes();
            modelBase64 = Base64.getEncoder().encodeToString(modelBytes);
        }

        try {
            // Create the model configuration and upload it
            client.ml().putTrainedModel(PutTrainedModelRequest.of(builder ->
                builder
                    .modelId(modelId)
                    .inferenceConfig(ic -> ic.onnx(onnx -> onnx))
                    .modelType("onnx")
                    .tags(Arrays.asList(tags))
                    .description(description)
                    .definition(def -> def.modelBytes(modelBase64))
            ));
        } catch (ElasticsearchException e) {
            if (e.status() == 413) {  // Request entity too large
                logger.warn("Model too large, falling back to chunked upload");
                uploadLargeModel(client, modelId, modelPath, description, tags);
            } else {
                throw e;
            }
        }
    }

    /**
     * Uploads a large model in chunks.
     */
    private void uploadLargeModel(
            ElasticsearchClient client,
            String modelId,
            Path modelPath,
            String description,
            String[] tags) throws IOException {

        // Create the model configuration (without the model content)
        createModelConfig(client, modelId, description, tags);

        // Upload the model content in chunks
        uploadModelChunks(client, modelId, modelPath);
    }

    /**
     * Creates the model configuration.
     */
    private void createModelConfig(
            ElasticsearchClient client,
            String modelId,
            String description,
            String[] tags) throws IOException {

        client.ml().putTrainedModel(builder ->
            builder
                .modelId(modelId)
                .inferenceConfig(ic -> ic.onnx(onnx -> onnx))
                .modelType("onnx")
                .tags(Arrays.asList(tags))
                .description(description)
        );

        logger.info("创建模型 {} 配置成功", modelId);
    }

    /**
     * Uploads the model definition in chunks.
     */
    private void uploadModelChunks(
            ElasticsearchClient client,
            String modelId,
            Path modelPath) throws IOException {

        // Determine the file size and number of parts
        long fileSize = Files.size(modelPath);
        int totalParts = (int) Math.ceil((double) fileSize / CHUNK_SIZE);

        try (InputStream is = Files.newInputStream(modelPath)) {
            byte[] buffer = new byte[CHUNK_SIZE];
            int partNum = 0;
            int bytesRead;
            long totalBytesRead = 0;

            // Read and upload each chunk
            while ((bytesRead = is.read(buffer)) != -1) {
                // If fewer bytes were read than the buffer size, copy into a right-sized array
                byte[] chunk = bytesRead == buffer.length ? buffer : Arrays.copyOf(buffer, bytesRead);
                String base64Chunk = Base64.getEncoder().encodeToString(chunk);
                // Lambdas may only capture effectively final locals
                final int currentPart = partNum;

                ESClientUtil.executeWithRetry(() -> {
                    client.ml().putTrainedModelDefinitionPart(d -> d
                        .modelId(modelId)
                        .part(currentPart)
                        .totalDefinitionLength(fileSize)
                        .totalParts(totalParts)
                        .definition(base64Chunk)
                    );
                    return null;
                });

                totalBytesRead += bytesRead;
                int progressPercent = (int)((totalBytesRead * 100) / fileSize);
                logger.info("Model {} upload progress: {}/{} parts ({}%)",
                        modelId, partNum + 1, totalParts, progressPercent);

                partNum++;
            }
        }
    }

    /**
     * Verifies the model status.
     */
    private void verifyModelStatus(ElasticsearchClient client, String modelId) {
        try {
            // Poll until the model definition is fully stored and readable
            int attempts = 0;
            boolean modelReady = false;
            int maxAttempts = MAX_WAIT_SECONDS; // Wait up to 2 minutes

            while (attempts < maxAttempts && !modelReady) {
                var response = client.ml().getTrainedModels(m -> m.modelId(modelId));
                if (response.trainedModelConfigs().isEmpty()) {
                    throw new ModelOperationException("Model not found", modelId, 404);
                }

                var modelInfo = response.trainedModelConfigs().get(0);
                // fullyDefined() reports whether all definition parts have been
                // stored; it may be absent (null) depending on the server version
                Boolean fullyDefined = modelInfo.fullyDefined();

                if (fullyDefined == null || fullyDefined) {
                    modelReady = true;
                    logger.info("Model {} is ready", modelId);
                } else {
                    logger.info("Model {} definition incomplete, waiting... (attempt {}/{})",
                            modelId, attempts + 1, maxAttempts);
                    Thread.sleep(1000);  // Wait one second before checking again
                    attempts++;
                }
            }

            if (!modelReady) {
                logger.warn("Model {} did not become ready within {} seconds", modelId, MAX_WAIT_SECONDS);
                throw new ModelOperationException(
                        "Model load timed out, check its status later", modelId, 408);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.warn("Interrupted while waiting for the model to load", e);
            throw new ModelOperationException("Model verification interrupted", modelId, 500, e);
        } catch (IOException e) {
            logger.warn("Failed to verify model status: {}", e.getMessage(), e);
            throw new ModelOperationException("Failed to verify model status", modelId, 500, e);
        } catch (ModelOperationException e) {
            throw e;
        } catch (Exception e) {
            logger.warn("Failed to verify model status: {}", e.getMessage(), e);
            throw new ModelOperationException("Failed to verify model status", modelId, 500, e);
        }
    }

    /**
     * Checks whether a model exists.
     */
    @Override
    public boolean modelExists(String modelId) {
        Objects.requireNonNull(modelId, "modelId must not be null");

        // Acquire the read lock to allow concurrent checks
        modelUpdateLock.readLock().lock();
        try (var client = ESClientUtil.createClient()) {
            var response = client.ml().getTrainedModels(m -> m.modelId(modelId));
            return !response.trainedModelConfigs().isEmpty();
        } catch (ElasticsearchException e) {
            if (e.status() == 404) {
                return false;
            }
            logger.error("检查模型存在性失败: {}", e.getMessage(), e);
            throw new ModelOperationException("检查模型失败", modelId, e.status(), e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new ModelOperationException("检查模型IO异常", modelId, 500, e);
        } finally {
            modelUpdateLock.readLock().unlock();
        }
    }

    /**
     * Creates a model alias.
     */
    @Override
    public void createModelAlias(String modelId, String aliasName) {
        Objects.requireNonNull(modelId, "modelId must not be null");
        Objects.requireNonNull(aliasName, "aliasName must not be null");

        modelUpdateLock.writeLock().lock();
        try (var client = ESClientUtil.createClient()) {
            client.ml().putTrainedModelAlias(a -> a
                .modelId(modelId)
                .modelAlias(aliasName)
                .reassign(true)
            );
            logger.info("为模型 {} 创建别名 {}", modelId, aliasName);
        } catch (ElasticsearchException e) {
            logger.error("创建模型别名失败: {}", e.getMessage(), e);
            throw new ModelOperationException("创建模型别名失败", modelId, e.status(), e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new ModelOperationException("创建模型别名IO异常", modelId, 500, e);
        } finally {
            modelUpdateLock.writeLock().unlock();
        }
    }
}
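
A usage sketch for the uploader (the model ID, file path, and alias are placeholders):

import java.nio.file.Path;

public class ModelUploadExample {
    public static void main(String[] args) {
        ModelUploader uploader = new ModelUploadService();
        String modelId = "sentiment-model-v1";              // placeholder model ID
        Path modelPath = Path.of("models/sentiment.onnx");  // placeholder local path

        if (!uploader.modelExists(modelId)) {
            uploader.uploadModel(modelId, modelPath, "Sentiment classifier", "nlp", "sentiment");
        }

        // Point a stable alias at this version so pipelines don't hard-code the version
        uploader.createModelAlias(modelId, "sentiment-model-current");
    }
}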

4.4 Inference Pipeline Interface and Implementation

Again, define the pipeline management interface first:

/**
 * Inference pipeline service interface
 */
public interface InferencePipelineManager {
    /**
     * Creates an inference pipeline.
     * @param pipelineId pipeline ID
     * @param modelId model ID
     * @param sourceField source field name
     * @param targetField target field name
     * @param description pipeline description
     */
    void createInferencePipeline(String pipelineId, String modelId,
                                String sourceField, String targetField,
                                String description);

    /**
     * Ensures the pipeline exists, creating it if missing.
     */
    void ensurePipelineExists(String pipelineId, String modelId,
                             String sourceField, String targetField,
                             String description);

    /**
     * Points an existing pipeline at a new model.
     */
    void updatePipelineModel(String pipelineId, String newModelId);
}

The implementation:

import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.ingest.Processor;
import co.elastic.clients.elasticsearch.ingest.PutPipelineRequest;
import co.elastic.clients.elasticsearch._types.ElasticsearchException;
import co.elastic.clients.json.JsonData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

@Service
public class InferencePipelineService implements InferencePipelineManager {
    private static final Logger logger = LoggerFactory.getLogger(InferencePipelineService.class);

    // Pipeline cache to avoid redundant creation
    private final ConcurrentHashMap<String, Long> pipelineCache = new ConcurrentHashMap<>();
    private final ReadWriteLock pipelineLock = new ReentrantReadWriteLock();

    // Cache TTL
    private static final long CACHE_TTL_MS = 3600000; // 1 hour

    // Inference tuning parameters
    private static final int DEFAULT_NUM_TOP_CLASSES = 2;

    /**
     * Creates an inference pipeline.
     */
    @Override
    public void createInferencePipeline(
            String pipelineId,
            String modelId,
            String sourceField,
            String targetField,
            String description) {

        // Validate arguments
        Objects.requireNonNull(pipelineId, "pipelineId must not be null");
        Objects.requireNonNull(modelId, "modelId must not be null");
        Objects.requireNonNull(sourceField, "sourceField must not be null");
        Objects.requireNonNull(targetField, "targetField must not be null");

        pipelineLock.writeLock().lock();
        try (var client = ESClientUtil.createClient()) {
            // Build the inference processor; thread counts are configured on the
            // model deployment rather than on the processor itself
            Processor inferenceProcessor = Processor.of(p -> p
                .inference(ip -> ip
                    .modelId(modelId)
                    .targetField(targetField)
                    .fieldMap(Map.of(sourceField, JsonData.of("input_text")))
                    .inferenceConfig(ic -> ic
                        .classification(c -> c
                            .numTopClasses(DEFAULT_NUM_TOP_CLASSES)
                        )
                    )
                )
            );

            // Create the pipeline containing the inference processor
            client.ingest().putPipeline(PutPipelineRequest.of(builder ->
                builder
                    .id(pipelineId)
                    .description(description != null ? description : "Inference pipeline " + pipelineId)
                    .processors(inferenceProcessor)
            ));

            logger.info("推理管道 {} 创建成功, 使用模型 {}", pipelineId, modelId);

            // 验证管道
            verifyPipeline(client, pipelineId);

            // 更新缓存
            pipelineCache.put(pipelineId, System.currentTimeMillis());
        } catch (ElasticsearchException e) {
            logger.error("创建推理管道失败: {}, 状态码: {}", e.getMessage(), e.status(), e);
            throw new InferenceException("创建推理管道失败: " + e.getMessage(), pipelineId, e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new InferenceException("创建推理管道IO异常", pipelineId, e);
        } finally {
            pipelineLock.writeLock().unlock();
        }
    }

    /**
     * Verifies that the pipeline works.
     */
    private void verifyPipeline(ElasticsearchClient client, String pipelineId) {
        try {
            // Run a simple test document through the pipeline
            Map<String, Object> testDoc = Map.of(
                "text", "This is a test sentence used to verify the inference pipeline."
            );

            var response = client.ingest().simulate(s -> s
                .id(pipelineId)
                .docs(d -> d.source(JsonData.of(testDoc)))
            );

            if (response.docs().isEmpty()) {
                throw new InferenceException("Pipeline test returned no results", pipelineId);
            }

            logger.info("Pipeline {} test succeeded", pipelineId);
        } catch (Exception e) {
            logger.warn("Pipeline verification failed: {}", e.getMessage(), e);
            throw new InferenceException("Pipeline verification failed", pipelineId, e);
        }
    }

    /**
     * Checks whether the pipeline exists and creates it if missing.
     * Used to guarantee pipeline availability before inference.
     */
    @Override
    public void ensurePipelineExists(
            String pipelineId,
            String modelId,
            String sourceField,
            String targetField,
            String description) {

        // Check the cache first
        if (pipelineCache.containsKey(pipelineId)) {
            return;
        }

        // Check under a read lock
        pipelineLock.readLock().lock();
        try (var client = ESClientUtil.createClient()) {
            var response = client.ingest().getPipeline(g -> g.id(pipelineId));
            if (!response.result().isEmpty()) {
                // Pipeline already exists; refresh the cache
                pipelineCache.put(pipelineId, System.currentTimeMillis());
                return;
            }
        } catch (Exception e) {
            // Ignore lookup errors and attempt creation
            logger.debug("Error checking pipeline existence, attempting creation: {}", e.getMessage());
        } finally {
            pipelineLock.readLock().unlock();
        }

        // Pipeline missing; create it
        createInferencePipeline(pipelineId, modelId, sourceField, targetField, description);
    }

    /**
     * Points an existing pipeline at a new model.
     */
    @Override
    public void updatePipelineModel(String pipelineId, String newModelId) {
        Objects.requireNonNull(pipelineId, "pipelineId must not be null");
        Objects.requireNonNull(newModelId, "newModelId must not be null");

        pipelineLock.writeLock().lock();
        try (var client = ESClientUtil.createClient()) {
            // Fetch the existing pipeline
            var response = client.ingest().getPipeline(g -> g.id(pipelineId));
            if (response.result().isEmpty()) {
                throw new InferenceException("Pipeline does not exist, cannot update", pipelineId);
            }

            // Read the current pipeline configuration
            var pipeline = response.result().get(pipelineId);
            var description = pipeline.description();

            // Client model objects are immutable, so rebuild each inference
            // processor with the new model ID instead of mutating it in place
            boolean modelUpdated = false;
            var processors = new java.util.ArrayList<Processor>();
            for (var processor : pipeline.processors()) {
                if (processor.isInference()) {
                    var old = processor.inference();
                    processors.add(Processor.of(p -> p.inference(ip -> ip
                        .modelId(newModelId)
                        .targetField(old.targetField())
                        .fieldMap(old.fieldMap())
                        .inferenceConfig(old.inferenceConfig())
                    )));
                    modelUpdated = true;
                } else {
                    processors.add(processor);
                }
            }

            if (!modelUpdated) {
                throw new InferenceException("No inference processor found in pipeline", pipelineId);
            }

            // Store the updated pipeline
            client.ingest().putPipeline(p -> p
                .id(pipelineId)
                .description(description)
                .processors(processors)
            );

            logger.info("管道 {} 已更新使用新模型 {}", pipelineId, newModelId);

            // 验证更新后的管道
            verifyPipeline(client, pipelineId);

            // 更新缓存
            pipelineCache.put(pipelineId, System.currentTimeMillis());
        } catch (ElasticsearchException e) {
            logger.error("更新管道模型失败: {}, 状态码: {}", e.getMessage(), e.status(), e);
            throw new InferenceException("更新管道模型失败", pipelineId, e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new InferenceException("更新管道IO异常", pipelineId, e);
        } finally {
            pipelineLock.writeLock().unlock();
        }
    }

    /**
     * Periodically evicts expired cache entries.
     */
    @Scheduled(fixedRate = 3600000) // Run hourly
    public void cleanupCache() {
        long now = System.currentTimeMillis();
        int removedCount = 0;

        for (Map.Entry<String, Long> entry : pipelineCache.entrySet()) {
            if (now - entry.getValue() > CACHE_TTL_MS) {
                pipelineCache.remove(entry.getKey());
                removedCount++;
            }
        }

        if (removedCount > 0) {
            logger.info("缓存清理完成,移除了{}个过期管道记录", removedCount);
        }
    }
}
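
A usage sketch showing a typical lifecycle, first deployment followed by an in-place model rollover (all IDs are placeholders):

public class PipelineRolloutExample {
    public static void main(String[] args) {
        InferencePipelineManager pipelines = new InferencePipelineService();

        // First deployment: create the pipeline if it is missing
        pipelines.ensurePipelineExists(
                "sentiment-pipeline", "sentiment-model-v1",
                "text", "prediction", "Sentiment inference pipeline");

        // Later: roll the pipeline over to a new model version in place
        pipelines.updatePipelineModel("sentiment-pipeline", "sentiment-model-v2");
    }
}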

4.5 System Initialization and Configuration

Add a system initialization component that performs the required setup at startup:

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;

@Component
public class ESMLSystemInitializer {
    private static final Logger logger = LoggerFactory.getLogger(ESMLSystemInitializer.class);

    private final IndexManager indexManager;
    private final ElasticsearchProperties esProperties;

    @Autowired
    public ESMLSystemInitializer(
            IndexManager indexManager,
            ElasticsearchProperties esProperties) {
        this.indexManager = indexManager;
        this.esProperties = esProperties;
    }

    @PostConstruct
    public void initialize() {
        logger.info("初始化ES ML系统...");

        // 异步初始化,避免阻塞应用启动
        CompletableFuture.runAsync(() -> {
            try {
                // 1. 验证集群连接
                if (!ESClientUtil.isClusterHealthy()) {
                    logger.warn("Elasticsearch集群状态不健康,初始化可能不完整");
                }

                // 2. Create the index template
                try (var client = ESClientUtil.createClient()) {
                    indexManager.createInferenceIndexTemplate(client);

                    // 3. Create the required indices
                    for (String indexName : esProperties.getRequiredIndices()) {
                        indexManager.createInferenceIndex(client, indexName);
                    }
                }

                logger.info("ES ML系统初始化完成");
            } catch (IOException e) {
                logger.error("初始化ES ML系统失败: {}", e.getMessage(), e);
            }
        });
    }
}
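
ElasticsearchProperties is referenced above but not defined in this installment. A minimal sketch of what such a configuration bean could look like (the esml prefix and property name are assumptions):

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

import java.util.List;

/**
 * Minimal configuration holder for the initializer above.
 * Binds a property such as: esml.required-indices=inference-demo,ml-results-demo
 */
@Component
@ConfigurationProperties(prefix = "esml")
public class ElasticsearchProperties {
    /** Indices the initializer should create at startup. */
    private List<String> requiredIndices = List.of();

    public List<String> getRequiredIndices() {
        return requiredIndices;
    }

    public void setRequiredIndices(List<String> requiredIndices) {
        this.requiredIndices = requiredIndices;
    }
}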

5. Inference Service Design and Implementation

Next, we design more concrete inference interfaces and their implementation:

/**
 * Single-text inference interface
 */
public interface SingleInference {
    /**
     * Runs inference on a single text.
     * @param indexName index name
     * @param pipelineId pipeline ID
     * @param modelId model ID
     * @param text input text
     * @return inference result
     */
    Map<String, Object> inferSingle(String indexName, String pipelineId,
                                   String modelId, String text);
}

/**
 * Batch inference interface
 */
public interface BatchInference {
    /**
     * Runs inference on a batch of texts.
     * @param indexName index name
     * @param pipelineId pipeline ID
     * @param modelId model ID
     * @param texts input texts
     * @return list of inference results
     */
    List<Map<String, Object>> inferBatch(String indexName, String pipelineId,
                                        String modelId, List<String> texts);
}

/**
 * Chained inference interface
 */
public interface ChainedInference {
    /**
     * Runs multi-model chained inference.
     * @param modelChain model chain configuration
     * @param text input text
     * @return inference result
     */
    Map<String, Object> chainedInference(List<ModelConfig> modelChain, String text);

    /**
     * Model configuration class
     */
    class ModelConfig {
        private final String modelId;
        private final String pipelineId;
        private final String indexName;
        private final String sourceField;
        private final String targetField;

        public ModelConfig(String modelId, String pipelineId, String indexName,
                          String sourceField, String targetField) {
            this.modelId = modelId;
            this.pipelineId = pipelineId;
            this.indexName = indexName;
            this.sourceField = sourceField;
            this.targetField = targetField;
        }

        // Getters
        public String getModelId() { return modelId; }
        public String getPipelineId() { return pipelineId; }
        public String getIndexName() { return indexName; }
        public String getSourceField() { return sourceField; }
        public String getTargetField() { return targetField; }
    }
}

/**
 * Complete inference service interface
 */
public interface InferenceService extends SingleInference, BatchInference, ChainedInference {
    // Combines the three interfaces above
}

An inference service implementing these interfaces:

import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.core.*;
import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem;
import co.elastic.clients.elasticsearch._types.ElasticsearchException;
import co.elastic.clients.elasticsearch._types.Result;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

@Service
public class ESInferenceService implements InferenceService {
    private static final Logger logger = LoggerFactory.getLogger(ESInferenceService.class);

    private final ModelUploader modelUploader;
    private final InferencePipelineManager pipelineManager;

    // Batch size limits
    private static final int MAX_BATCH_SIZE = 100;
    private static final int DEFAULT_BATCH_SIZE = 20;

    @Autowired
    public ESInferenceService(ModelUploader modelUploader,
                             InferencePipelineManager pipelineManager) {
        this.modelUploader = modelUploader;
        this.pipelineManager = pipelineManager;
    }

    /**
     * Single-document inference.
     */
    @Override
    public Map<String, Object> inferSingle(
            String indexName,
            String pipelineId,
            String modelId,
            String text) {

        // Validate arguments
        Objects.requireNonNull(indexName, "indexName must not be null");
        Objects.requireNonNull(pipelineId, "pipelineId must not be null");
        Objects.requireNonNull(modelId, "modelId must not be null");
        Objects.requireNonNull(text, "text must not be null");

        // Make sure the pipeline exists
        pipelineManager.ensurePipelineExists(
                pipelineId, modelId, "text", "prediction", "Inference pipeline");

        try (var client = ESClientUtil.createClient()) {
            // Prepare the document
            String documentId = generateDocumentId();
            Map<String, Object> document = new HashMap<>();
            document.put("text", text);
            document.put("timestamp", System.currentTimeMillis());

            // Record the start time
            long startTime = System.currentTimeMillis();

            // Index the document through the inference pipeline
            IndexResponse response = ESClientUtil.executeWithRetry(() ->
                client.index(IndexRequest.of(builder ->
                    builder
                        .index(indexName)
                        .id(documentId)
                        .pipeline(pipelineId)
                        .document(document)
                ))
            );

            if (!response.result().name().contains("CREATED") &&
                !response.result().name().contains("UPDATED")) {
                throw new InferenceException(
                        "索引文档失败: " + response.result().name(), pipelineId);
            }

            // Fetch the inference result
            GetResponse<Map> getResponse = ESClientUtil.executeWithRetry(() ->
                client.get(g -> g
                        .index(indexName)
                        .id(documentId),
                        Map.class)
            );

            if (!getResponse.found()) {
                throw new InferenceException("推理后文档未找到", pipelineId);
            }

            // Compute the processing time
            long processingTime = System.currentTimeMillis() - startTime;
            Map<String, Object> result = getResponse.source();
            result.put("processing_time_ms", processingTime);

            logger.debug("单文本推理完成,处理时间: {}ms", processingTime);
            return result;
        } catch (ElasticsearchException e) {
            logger.error("ES推理异常: {}, 状态码: {}", e.getMessage(), e.status(), e);
            throw new InferenceException("推理失败: " + e.getMessage(), pipelineId, e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new InferenceException("推理IO异常", pipelineId, e);
        }
    }

    /**
     * Batch inference.
     */
    @Override
    public List<Map<String, Object>> inferBatch(
            String indexName,
            String pipelineId,
            String modelId,
            List<String> texts) {

        // Validate input
        Objects.requireNonNull(indexName, "indexName must not be null");
        Objects.requireNonNull(pipelineId, "pipelineId must not be null");
        Objects.requireNonNull(modelId, "modelId must not be null");

        if (texts == null || texts.isEmpty()) {
            return Collections.emptyList();
        }

        // Make sure the pipeline exists
        pipelineManager.ensurePipelineExists(
                pipelineId, modelId, "text", "prediction", "Inference pipeline");

        // Split oversized inputs into smaller batches
        if (texts.size() > MAX_BATCH_SIZE) {
            return processBatchesInChunks(indexName, pipelineId, texts);
        }

        try (var client = ESClientUtil.createClient()) {
            // Prepare the bulk request; the pipeline applies to every operation in it
            List<String> documentIds = new ArrayList<>(texts.size());
            BulkRequest.Builder bulkBuilder = new BulkRequest.Builder().pipeline(pipelineId);

            // Record the start time
            long startTime = System.currentTimeMillis();

            // Add all documents to the bulk request
            for (String text : texts) {
                String documentId = generateDocumentId();
                documentIds.add(documentId);

                Map<String, Object> document = new HashMap<>();
                document.put("text", text);
                document.put("timestamp", System.currentTimeMillis());

                bulkBuilder.operations(op -> op
                    .index(idx -> idx
                        .index(indexName)
                        .id(documentId)
                        .document(document)
                    )
                );
            }

            // Execute the bulk request
            BulkRequest bulkRequest = bulkBuilder.build();
            BulkResponse bulkResponse = ESClientUtil.executeWithRetry(() ->
                client.bulk(bulkRequest)
            );

            // Check the bulk result
            if (bulkResponse.errors()) {
                logger.warn("Batch inference partially failed");
                for (BulkResponseItem item : bulkResponse.items()) {
                    if (item.error() != null) {
                        logger.error("Document {} failed: {}",
                                item.id(), item.error().reason());
                    }
                }
            }

            // Fetch the result documents in one mget call
            MgetResponse<Map> response = ESClientUtil.executeWithRetry(() ->
                client.mget(m -> m
                    .index(indexName)
                    .ids(documentIds),
                    Map.class)
            );

            // Compute the total processing time
            long totalProcessingTime = System.currentTimeMillis() - startTime;
            long avgProcessingTime = totalProcessingTime / texts.size();

            // Extract the results; mget items are a union of result and failure
            List<Map<String, Object>> results = new ArrayList<>(texts.size());
            for (var doc : response.docs()) {
                if (doc.isResult() && doc.result().found()) {
                    Map<String, Object> result = doc.result().source();
                    result.put("processing_time_ms", avgProcessingTime);
                    results.add(result);
                } else {
                    Map<String, Object> errorDoc = new HashMap<>();
                    errorDoc.put("error", "Document not found after inference");
                    errorDoc.put("id", doc.isResult() ? doc.result().id() : doc.failure().id());
                    results.add(errorDoc);
                }
            }

            logger.debug("批量推理完成,{}条文本,总耗时: {}ms,平均: {}ms/条",
                    texts.size(), totalProcessingTime, avgProcessingTime);
            return results;
        } catch (ElasticsearchException e) {
            logger.error("ES批量推理异常: {}, 状态码: {}", e.getMessage(), e.status(), e);
            throw new InferenceException("批量推理失败: " + e.getMessage(), pipelineId, e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new InferenceException("批量推理IO异常", pipelineId, e);
        }
    }

    /**
     * Processes a large input in smaller chunks.
     */
    private List<Map<String, Object>> processBatchesInChunks(
            String indexName, String pipelineId, List<String> texts) {

        List<Map<String, Object>> allResults = new ArrayList<>(texts.size());
        List<List<String>> batches = new ArrayList<>();

        // Split into batches of DEFAULT_BATCH_SIZE
        for (int i = 0; i < texts.size(); i += DEFAULT_BATCH_SIZE) {
            int end = Math.min(i + DEFAULT_BATCH_SIZE, texts.size());
            batches.add(texts.subList(i, end));
        }

        // Process each batch
        AtomicInteger processedCount = new AtomicInteger(0);
        for (List<String> batch : batches) {
            try {
                List<Map<String, Object>> batchResults =
                    inferBatchInternal(indexName, pipelineId, batch);
                allResults.addAll(batchResults);

                int completed = processedCount.addAndGet(batch.size());
                logger.info("批量推理进度: {}/{} ({}%)",
                        completed, texts.size(), (completed * 100 / texts.size()));
            } catch (Exception e) {
                logger.error("处理批次失败: {}", e.getMessage(), e);
                // 添加错误结果
                for (int i = 0; i < batch.size(); i++) {
                    Map<String, Object> errorDoc = new HashMap<>();
                    errorDoc.put("error", "批次处理失败: " + e.getMessage());
                    errorDoc.put("text", batch.get(i));
                    allResults.add(errorDoc);
                }
            }
        }

        return allResults;
    }

    /**
     * Internal batch inference.
     * Handles a single chunk and enforces the batch size limit.
     */
    private List<Map<String, Object>> inferBatchInternal(
            String indexName, String pipelineId, List<String> batch) {

        if (batch.size() > MAX_BATCH_SIZE) {
            throw new IllegalArgumentException(
                    "Batch size exceeds limit: " + batch.size() + " > " + MAX_BATCH_SIZE);
        }

        try (var client = ESClientUtil.createClient()) {
            // Prepare the bulk request; the pipeline applies to every operation in it
            List<String> documentIds = new ArrayList<>(batch.size());
            BulkRequest.Builder bulkBuilder = new BulkRequest.Builder().pipeline(pipelineId);

            // Add all documents to the bulk request
            for (String text : batch) {
                String documentId = generateDocumentId();
                documentIds.add(documentId);

                Map<String, Object> document = new HashMap<>();
                document.put("text", text);
                document.put("timestamp", System.currentTimeMillis());

                bulkBuilder.operations(op -> op
                    .index(idx -> idx
                        .index(indexName)
                        .id(documentId)
                        .document(document)
                    )
                );
            }

            // Execute the bulk request
            BulkRequest bulkRequest = bulkBuilder.build();
            BulkResponse bulkResponse = ESClientUtil.executeWithRetry(() ->
                client.bulk(bulkRequest)
            );

            // Log any per-item errors
            checkBulkResponse(bulkResponse);

            // Fetch the result documents in one mget call
            MgetResponse<Map> response = ESClientUtil.executeWithRetry(() ->
                client.mget(m -> m
                    .index(indexName)
                    .ids(documentIds),
                    Map.class)
            );

            // Extract the results; mget items are a union of result and failure
            List<Map<String, Object>> results = new ArrayList<>(batch.size());
            for (var doc : response.docs()) {
                if (doc.isResult() && doc.result().found()) {
                    results.add(doc.result().source());
                } else {
                    Map<String, Object> errorDoc = new HashMap<>();
                    errorDoc.put("error", "Document not found after inference");
                    errorDoc.put("id", doc.isResult() ? doc.result().id() : doc.failure().id());
                    results.add(errorDoc);
                }
            }

            return results;
        } catch (Exception e) {
            logger.error("批次处理异常: {}", e.getMessage(), e);
            throw new InferenceException("批次处理失败", pipelineId, e);
        }
    }

    /**
     * Logs errors from a bulk response.
     */
    private void checkBulkResponse(BulkResponse response) {
        if (response.errors()) {
            StringBuilder errorMsg = new StringBuilder("Bulk operation partially failed: ");
            for (BulkResponseItem item : response.items()) {
                if (item.error() != null) {
                    errorMsg.append(item.id())
                           .append("(")
                           .append(item.error().reason())
                           .append("), ");
                }
            }
            logger.warn(errorMsg.toString());
        }
    }

    /**
     * Multi-model chained inference.
     */
    @Override
    public Map<String, Object> chainedInference(
            List<ModelConfig> modelChain, String text) {

        // Validate arguments
        Objects.requireNonNull(modelChain, "modelChain must not be null");
        if (modelChain.isEmpty()) {
            throw new IllegalArgumentException("modelChain must not be empty");
        }
        Objects.requireNonNull(text, "text must not be null");

        Map<String, Object> document = new HashMap<>();
        document.put("original_text", text);
        document.put("text", text);
        document.put("timestamp", System.currentTimeMillis());

        // Run each model in the chain in order
        for (int i = 0; i < modelChain.size(); i++) {
            ModelConfig config = modelChain.get(i);
            String stageName = "stage_" + (i + 1);

            try {
                // Make sure the pipeline for this stage exists
                String pipelineId = config.getPipelineId();
                pipelineManager.ensurePipelineExists(
                        pipelineId, config.getModelId(),
                        config.getSourceField(), config.getTargetField(),
                        "Chained inference pipeline " + stageName);

                // Run the current stage
                document = runPipelineStage(
                        config.getIndexName(), pipelineId, document, stageName);

                // If there is a next model, feed the current result into it
                if (i < modelChain.size() - 1) {
                    ModelConfig nextConfig = modelChain.get(i + 1);
                    // Extract the data the next stage should consume
                    Object nextInput = extractFieldFromPath(
                            document, config.getTargetField());
                    // Keep it as a string or structured value, depending on what the next model expects
                    document.put(nextConfig.getSourceField(), nextInput);
                }
            } catch (Exception e) {
                logger.error("链式推理阶段 {} 失败: {}", stageName, e.getMessage(), e);
                document.put("error_" + stageName, e.getMessage());
                // 链式失败,中断后续处理
                break;
            }
        }

        return document;
    }

    /**
     * Runs a single pipeline stage.
     */
    private Map<String, Object> runPipelineStage(
            String indexName, String pipelineId,
            Map<String, Object> document, String stageName) throws IOException {

        try (var client = ESClientUtil.createClient()) {
            String documentId = generateDocumentId();

            // Index the document, applying the stage pipeline
            IndexResponse response = ESClientUtil.executeWithRetry(() ->
                client.index(IndexRequest.of(builder ->
                    builder
                        .index(indexName)
                        .id(documentId)
                        .pipeline(pipelineId)
                        .document(document)
                ))
            );

            // Fetch the processed document
            GetResponse<Map> getResponse = ESClientUtil.executeWithRetry(() ->
                client.get(g -> g
                        .index(indexName)
                        .id(documentId),
                        Map.class)
            );

            if (!getResponse.found()) {
                throw new InferenceException("Document not found after stage " + stageName, pipelineId);
            }

            // Tag the result with its stage marker
            Map<String, Object> result = getResponse.source();
            result.put("_stage", stageName);

            return result;
        }
    }

    /**
     * Extracts a value from a nested field path.
     */
    private Object extractFieldFromPath(Map<String, Object> document, String fieldPath) {
        if (fieldPath == null || fieldPath.isEmpty()) {
            return null;
        }

        String[] parts = fieldPath.split("\\.");
        Object current = document;

        for (String part : parts) {
            if (current instanceof Map) {
                current = ((Map<?, ?>) current).get(part);
                if (current == null) {
                    return null;
                }
            } else {
                return null;
            }
        }

        return current;
    }

    /**
     * Generates a unique document ID.
     */
    private String generateDocumentId() {
        return UUID.randomUUID().toString();
    }
}
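
To close this part, a usage sketch for the chained inference API (model and pipeline IDs are placeholders for a hypothetical two-stage chain):

import java.util.List;
import java.util.Map;

public class ChainedInferenceExample {
    public static void run(InferenceService inferenceService) {
        // Hypothetical two-stage chain: a topic classifier feeding a sentiment model
        List<ChainedInference.ModelConfig> chain = List.of(
            new ChainedInference.ModelConfig(
                "topic-model-v1", "topic-pipeline", "inference-demo",
                "text", "prediction"),
            new ChainedInference.ModelConfig(
                "sentiment-model-v1", "sentiment-pipeline", "inference-demo",
                "text", "prediction")
        );

        Map<String, Object> result = inferenceService.chainedInference(
                chain, "The service was quick and the staff were friendly.");
        System.out.println("Chained result: " + result);
    }
}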

This part of the code provides the following:

  1. Basic exception types that make error handling explicit
  2. An ES client utility with connection pooling, retries, and SSL configuration
  3. A model upload service with streaming support for large files
  4. Inference pipeline management, including caching and cache expiry
  5. A system initializer that performs the required setup at startup
  6. A multi-interface inference service separating single, batch, and chained inference

In the next part, we will dig into advanced capabilities, including:

  • Circuit breakers and failure degradation strategies
  • Model optimization techniques
  • Performance monitoring
  • Spring Boot integration
  • Real-world use cases