22-LangChain4J框架设计深度解析

1 阅读14分钟

深入理解LangChain4J的框架设计和核心机制,能够帮助我们更好地使用框架、排查问题,并在需要时进行定制化扩展。

时间:45分钟 | 难度:⭐⭐⭐⭐ | Week 4 Day 22

📋 学习目标

  • 理解LangChain4J的整体架构和模块划分
  • 掌握核心接口的设计思想和职责
  • 深入理解AiServices的动态代理实现机制
  • 了解Tool系统的注解处理和函数调用原理
  • 掌握Memory系统的架构设计
  • 理解RAG系统的完整管道设计
  • 识别框架中应用的设计模式
  • 学会通过SPI扩展框架功能

🏗️ 框架整体架构

LangChain4J采用模块化设计,将功能划分为多个独立的Maven模块:

核心模块组成

langchain4j-parent
├── langchain4j-core          # 核心接口和抽象
│   ├── model interfaces      # ChatLanguageModel, EmbeddingModel等
│   ├── data structures       # ChatMessage, TextSegment等
│   ├── store interfaces      # EmbeddingStore, ChatMemoryStore等
│   └── rag interfaces        # ContentRetriever, QueryTransformer等
│
├── langchain4j               # 高层抽象和组合功能
│   ├── AiServices           # 动态代理服务
│   ├── chain builders       # ConversationalChain, RetrievalChain等
│   ├── memory               # 内存管理实现
│   └── rag                  # RAG增强实现
│
├── langchain4j-open-ai      # OpenAI集成
├── langchain4j-azure-open-ai # Azure OpenAI集成
├── langchain4j-ollama       # Ollama集成
├── langchain4j-pgvector     # PostgreSQL向量存储
├── langchain4j-redis        # Redis向量存储
├── langchain4j-chroma       # Chroma向量数据库
└── langchain4j-spring-boot-starter # Spring Boot自动配置

架构分层

┌─────────────────────────────────────────────────────┐
│           Application Layer (Your Code)              │
│  @RestController, @Service, Business Logic           │
└────────────────────┬────────────────────────────────┘
                     │
┌────────────────────▼────────────────────────────────┐
│        High-Level API (langchain4j)                  │
│  AiServices, ConversationalChain, RetrievalChain     │
└────────────────────┬────────────────────────────────┘
                     │
┌────────────────────▼────────────────────────────────┐
│     Core Abstractions (langchain4j-core)             │
│  Interfaces: ChatLanguageModel, EmbeddingStore, etc. │
└────────────────────┬────────────────────────────────┘
                     │
┌────────────────────▼────────────────────────────────┐
│   Provider Implementations (langchain4j-xxx)         │
│  OpenAI, Ollama, PGVector, Redis, etc.               │
└─────────────────────────────────────────────────────┘

设计优势

  1. 依赖倒置:应用代码依赖抽象接口,不依赖具体实现
  2. 可替换性:可以轻松切换不同的模型提供商或存储后端
  3. 独立演进:各个模块可以独立版本管理和发布
  4. 按需引入:只需引入实际使用的模块,减小依赖体积

🔑 核心接口设计

LangChain4J的核心接口设计遵循单一职责原则,每个接口都有明确的职责边界。

ChatLanguageModel接口

这是所有聊天模型的顶层接口:

package dev.langchain4j.model.chat;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.output.Response;

import java.util.List;

/**
 * 聊天语言模型的核心接口
 */
public interface ChatLanguageModel {

    /**
     * 生成单次响应
     * @param messages 对话消息列表
     * @return 模型响应
     */
    Response<AiMessage> generate(List<ChatMessage> messages);

    /**
     * 生成简单文本响应(便捷方法)
     * @param userMessage 用户消息
     * @return 响应文本
     */
    default String generate(String userMessage) {
        UserMessage message = UserMessage.from(userMessage);
        Response<AiMessage> response = generate(List.of(message));
        return response.content().text();
    }
}

支持流式响应的扩展接口

package dev.langchain4j.model.chat;

import dev.langchain4j.data.message.ChatMessage;

import java.util.List;

/**
 * 支持流式输出的聊天模型
 */
public interface StreamingChatLanguageModel {

    /**
     * 流式生成响应
     * @param messages 对话消息列表
     * @param handler 流式响应处理器
     */
    void generate(List<ChatMessage> messages,
                  StreamingResponseHandler<AiMessage> handler);
}

具体实现示例(OpenAI)

public class OpenAiChatModel implements ChatLanguageModel {

    private final String apiKey;
    private final String modelName;
    private final Double temperature;
    private final Integer maxTokens;

    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        // 1. 转换消息格式
        List<OpenAiMessage> openAiMessages = messages.stream()
            .map(this::toOpenAiMessage)
            .collect(Collectors.toList());

        // 2. 构建请求
        OpenAiChatCompletionRequest request = OpenAiChatCompletionRequest.builder()
            .model(modelName)
            .messages(openAiMessages)
            .temperature(temperature)
            .maxTokens(maxTokens)
            .build();

        // 3. 调用API
        OpenAiChatCompletionResponse response = openAiClient.chatCompletion(request);

        // 4. 转换响应
        return Response.from(
            AiMessage.from(response.choices().get(0).message().content()),
            response.tokenUsage(),
            response.finishReason()
        );
    }
}

EmbeddingModel接口

用于文本向量化的核心接口:

package dev.langchain4j.model.embedding;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;

import java.util.List;

/**
 * 文本嵌入模型接口
 */
public interface EmbeddingModel {

    /**
     * 将单个文本转换为向量
     */
    Response<Embedding> embed(String text);

    /**
     * 将文本段转换为向量
     */
    Response<Embedding> embed(TextSegment textSegment);

    /**
     * 批量向量化(提高效率)
     */
    Response<List<Embedding>> embedAll(List<TextSegment> textSegments);

    /**
     * 返回向量维度
     */
    int dimension();
}

实现示例

public class OpenAiEmbeddingModel implements EmbeddingModel {

    private final String apiKey;
    private final String modelName; // text-embedding-ada-002

    @Override
    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        List<String> texts = textSegments.stream()
            .map(TextSegment::text)
            .collect(Collectors.toList());

        OpenAiEmbeddingRequest request = OpenAiEmbeddingRequest.builder()
            .model(modelName)
            .input(texts)
            .build();

        OpenAiEmbeddingResponse response = openAiClient.embedding(request);

        List<Embedding> embeddings = response.data().stream()
            .map(data -> Embedding.from(data.embedding()))
            .collect(Collectors.toList());

        return Response.from(embeddings, response.usage());
    }

    @Override
    public int dimension() {
        return 1536; // ada-002的维度
    }
}

EmbeddingStore接口

向量存储的抽象接口:

package dev.langchain4j.store.embedding;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;

import java.util.List;

/**
 * 向量存储接口
 */
public interface EmbeddingStore<Embedded> {

    /**
     * 添加单个向量
     * @return 向量ID
     */
    String add(Embedding embedding);

    /**
     * 添加向量及其关联的嵌入对象
     */
    String add(Embedding embedding, Embedded embedded);

    /**
     * 批量添加
     */
    List<String> addAll(List<Embedding> embeddings);

    /**
     * 相似度搜索
     * @param referenceEmbedding 查询向量
     * @param maxResults 最大结果数
     * @param minScore 最小相似度分数
     * @return 搜索结果列表
     */
    List<EmbeddingMatch<Embedded>> findRelevant(
        Embedding referenceEmbedding,
        int maxResults,
        double minScore
    );
}

搜索结果封装

public class EmbeddingMatch<Embedded> {
    private final double score;          // 相似度分数
    private final String embeddingId;    // 向量ID
    private final Embedding embedding;   // 向量本身
    private final Embedded embedded;     // 关联的对象(如TextSegment)

    // 构造器、getter等
}

🪄 AiServices动态代理机制

AiServices是LangChain4J最强大的功能之一,它使用Java动态代理将接口方法转换为LLM调用。

使用示例回顾

interface Assistant {
    @SystemMessage("你是一个友好的助手")
    String chat(String userMessage);
}

Assistant assistant = AiServices.create(Assistant.class, chatModel);
String response = assistant.chat("你好");

动态代理实现原理

package dev.langchain4j.service;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

public class AiServices {

    /**
     * 创建AI服务代理
     */
    public static <T> T create(Class<T> aiServiceClass,
                                ChatLanguageModel chatModel) {
        return (T) Proxy.newProxyInstance(
            aiServiceClass.getClassLoader(),
            new Class[]{aiServiceClass},
            new AiServiceInvocationHandler(chatModel)
        );
    }

    /**
     * 调用处理器
     */
    private static class AiServiceInvocationHandler implements InvocationHandler {

        private final ChatLanguageModel chatModel;
        private final ChatMemory chatMemory;
        private final List<ToolSpecification> tools;

        @Override
        public Object invoke(Object proxy, Method method, Object[] args)
                throws Throwable {

            // 1. 解析方法注解
            ServiceMethodInfo methodInfo = parseMethod(method);

            // 2. 构建消息列表
            List<ChatMessage> messages = buildMessages(methodInfo, args);

            // 3. 如果有tool,添加tool规范
            if (!tools.isEmpty()) {
                messages = enrichWithTools(messages, tools);
            }

            // 4. 调用模型
            Response<AiMessage> response = chatModel.generate(messages);

            // 5. 处理tool调用
            if (response.content().hasToolExecutionRequests()) {
                response = handleToolExecution(response);
            }

            // 6. 更新记忆
            if (chatMemory != null) {
                chatMemory.add(messages);
                chatMemory.add(response.content());
            }

            // 7. 转换返回值
            return convertReturnValue(response, method.getReturnType());
        }
    }
}

方法注解解析

/**
 * 方法元信息
 */
class ServiceMethodInfo {
    String systemMessage;
    String userMessageTemplate;
    Map<String, Integer> variablePositions;

    static ServiceMethodInfo parse(Method method) {
        ServiceMethodInfo info = new ServiceMethodInfo();

        // 解析@SystemMessage
        if (method.isAnnotationPresent(SystemMessage.class)) {
            SystemMessage annotation = method.getAnnotation(SystemMessage.class);
            info.systemMessage = annotation.value();
        }

        // 解析@UserMessage
        if (method.isAnnotationPresent(UserMessage.class)) {
            UserMessage annotation = method.getAnnotation(UserMessage.class);
            info.userMessageTemplate = annotation.value();

            // 解析模板变量
            info.variablePositions = parseTemplateVariables(
                annotation.value(),
                method.getParameters()
            );
        }

        return info;
    }
}

消息构建流程

class MessageBuilder {

    List<ChatMessage> buildMessages(ServiceMethodInfo methodInfo, Object[] args) {
        List<ChatMessage> messages = new ArrayList<>();

        // 1. 添加系统消息
        if (methodInfo.systemMessage != null) {
            messages.add(SystemMessage.from(methodInfo.systemMessage));
        }

        // 2. 从记忆中恢复历史消息
        if (chatMemory != null) {
            messages.addAll(chatMemory.messages());
        }

        // 3. 构建用户消息
        String userMessage = buildUserMessage(methodInfo, args);
        messages.add(UserMessage.from(userMessage));

        return messages;
    }

    String buildUserMessage(ServiceMethodInfo methodInfo, Object[] args) {
        if (methodInfo.userMessageTemplate != null) {
            // 使用模板替换变量
            return replaceVariables(
                methodInfo.userMessageTemplate,
                methodInfo.variablePositions,
                args
            );
        } else {
            // 直接使用第一个参数
            return String.valueOf(args[0]);
        }
    }
}

Builder模式配置

public class AiServices<T> {

    public static <T> AiServicesBuilder<T> builder(Class<T> aiServiceClass) {
        return new AiServicesBuilder<>(aiServiceClass);
    }

    public static class AiServicesBuilder<T> {
        private Class<T> aiServiceClass;
        private ChatLanguageModel chatLanguageModel;
        private StreamingChatLanguageModel streamingChatLanguageModel;
        private ChatMemory chatMemory;
        private ChatMemoryProvider chatMemoryProvider;
        private List<Object> tools = new ArrayList<>();
        private ContentRetriever contentRetriever;

        public AiServicesBuilder<T> chatLanguageModel(
                ChatLanguageModel model) {
            this.chatLanguageModel = model;
            return this;
        }

        public AiServicesBuilder<T> chatMemory(ChatMemory chatMemory) {
            this.chatMemory = chatMemory;
            return this;
        }

        public AiServicesBuilder<T> tools(Object... tools) {
            this.tools.addAll(Arrays.asList(tools));
            return this;
        }

        public T build() {
            // 验证必需配置
            validateConfiguration();

            // 解析工具
            List<ToolSpecification> toolSpecs = parseTools(tools);

            // 创建代理
            return createProxy(
                aiServiceClass,
                chatLanguageModel,
                chatMemory,
                toolSpecs
            );
        }
    }
}

🔧 Tool系统实现原理

Tool系统允许LLM调用Java方法来获取外部数据或执行操作。

@Tool注解处理

package dev.langchain4j.agent.tool;

import java.lang.annotation.*;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Tool {
    /**
     * 工具名称(发送给LLM)
     */
    String name() default "";

    /**
     * 工具描述(帮助LLM理解何时使用)
     */
    String value() default "";
}

参数注解

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface P {
    /**
     * 参数描述
     */
    String value();
}

工具定义示例

public class WeatherTools {

    @Tool("获取指定城市的当前天气信息")
    public String getCurrentWeather(
        @P("城市名称,如'北京'或'上海'") String city
    ) {
        // 实际应该调用天气API
        return String.format("城市:%s, 温度:22°C, 天气:晴", city);
    }

    @Tool("获取未来几天的天气预报")
    public String getWeatherForecast(
        @P("城市名称") String city,
        @P("预报天数,1-7之间") int days
    ) {
        return String.format("城市:%s, %d天预报:...", city, days);
    }
}

ToolSpecification生成

/**
 * 工具规范(发送给LLM的工具描述)
 */
public class ToolSpecification {
    private String name;
    private String description;
    private ToolParameters parameters;

    /**
     * 从Java方法解析工具规范
     */
    public static ToolSpecification from(Method method) {
        Tool toolAnnotation = method.getAnnotation(Tool.class);

        String name = toolAnnotation.name().isEmpty()
            ? method.getName()
            : toolAnnotation.name();

        String description = toolAnnotation.value();

        // 解析参数
        ToolParameters parameters = parseParameters(method);

        return new ToolSpecification(name, description, parameters);
    }

    private static ToolParameters parseParameters(Method method) {
        Map<String, ToolParameter> properties = new HashMap<>();
        List<String> required = new ArrayList<>();

        Parameter[] parameters = method.getParameters();
        for (int i = 0; i < parameters.length; i++) {
            Parameter param = parameters[i];
            P annotation = param.getAnnotation(P.class);

            String paramName = param.getName();
            String description = annotation != null ? annotation.value() : "";
            String type = mapJavaTypeToJsonType(param.getType());

            properties.put(paramName, new ToolParameter(type, description));
            required.add(paramName);
        }

        return new ToolParameters("object", properties, required);
    }
}

LLM眼中的工具

转换为JSON Schema格式发送给LLM:

{
  "type": "function",
  "function": {
    "name": "getCurrentWeather",
    "description": "获取指定城市的当前天气信息",
    "parameters": {
      "type": "object",
      "properties": {
        "city": {
          "type": "string",
          "description": "城市名称,如'北京'或'上海'"
        }
      },
      "required": ["city"]
    }
  }
}

工具执行流程

class ToolExecutor {

    private final Map<String, ToolExecutionRequest> toolRegistry;

    /**
     * 执行工具调用
     */
    String executeTool(ToolExecutionRequest request) {
        // 1. 查找工具方法
        ToolExecutionRequest toolMethod = toolRegistry.get(request.name());
        if (toolMethod == null) {
            return "Error: Tool not found: " + request.name();
        }

        try {
            // 2. 解析参数
            Object[] args = parseArguments(
                toolMethod.method(),
                request.arguments()
            );

            // 3. 执行方法
            Object result = toolMethod.method().invoke(
                toolMethod.instance(),
                args
            );

            // 4. 返回结果
            return String.valueOf(result);

        } catch (Exception e) {
            return "Error executing tool: " + e.getMessage();
        }
    }

    /**
     * 处理包含工具调用的响应
     */
    Response<AiMessage> handleToolExecution(Response<AiMessage> response) {
        List<ChatMessage> messages = new ArrayList<>();
        messages.add(response.content());

        // 执行所有工具调用
        for (ToolExecutionRequest request :
                response.content().toolExecutionRequests()) {

            String result = executeTool(request);

            messages.add(ToolExecutionResultMessage.from(
                request.id(),
                request.name(),
                result
            ));
        }

        // 再次调用LLM处理工具结果
        return chatLanguageModel.generate(messages);
    }
}

🧠 Memory系统架构

Memory系统负责管理对话历史,是实现上下文感知对话的关键。

ChatMemory接口

package dev.langchain4j.memory;

import dev.langchain4j.data.message.ChatMessage;
import java.util.List;

/**
 * 聊天记忆接口
 */
public interface ChatMemory {

    /**
     * 获取聊天记忆的唯一标识
     */
    Object id();

    /**
     * 添加消息到记忆
     */
    void add(ChatMessage message);

    /**
     * 批量添加消息
     */
    default void add(List<ChatMessage> messages) {
        messages.forEach(this::add);
    }

    /**
     * 获取所有消息
     */
    List<ChatMessage> messages();

    /**
     * 清空记忆
     */
    void clear();
}

MessageWindowChatMemory实现

基于滑动窗口的记忆实现,保留最近N条消息:

public class MessageWindowChatMemory implements ChatMemory {

    private final Object id;
    private final int maxMessages;
    private final LinkedList<ChatMessage> messages;
    private final ChatMemoryStore store; // 可选的持久化存储

    public MessageWindowChatMemory(int maxMessages) {
        this.id = UUID.randomUUID();
        this.maxMessages = maxMessages;
        this.messages = new LinkedList<>();
        this.store = null;
    }

    @Override
    public void add(ChatMessage message) {
        messages.add(message);

        // 超过最大消息数,移除最旧的
        while (messages.size() > maxMessages) {
            messages.removeFirst();
        }

        // 持久化到存储
        if (store != null) {
            store.updateMessages(id, messages);
        }
    }

    @Override
    public List<ChatMessage> messages() {
        return new ArrayList<>(messages);
    }

    @Override
    public void clear() {
        messages.clear();
        if (store != null) {
            store.deleteMessages(id);
        }
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private int maxMessages = 10;
        private Object id;
        private ChatMemoryStore store;

        public Builder maxMessages(int maxMessages) {
            this.maxMessages = maxMessages;
            return this;
        }

        public Builder id(Object id) {
            this.id = id;
            return this;
        }

        public Builder chatMemoryStore(ChatMemoryStore store) {
            this.store = store;
            return this;
        }

        public MessageWindowChatMemory build() {
            return new MessageWindowChatMemory(this);
        }
    }
}

TokenWindowChatMemory实现

基于token数量的记忆实现,确保不超过模型上下文限制:

public class TokenWindowChatMemory implements ChatMemory {

    private final Object id;
    private final int maxTokens;
    private final Tokenizer tokenizer;
    private final LinkedList<ChatMessage> messages;

    @Override
    public void add(ChatMessage message) {
        messages.add(message);

        // 计算当前token总数
        int totalTokens = calculateTotalTokens();

        // 移除旧消息直到token数在限制内
        while (totalTokens > maxTokens && messages.size() > 1) {
            ChatMessage removed = messages.removeFirst();
            totalTokens -= tokenizer.estimateTokenCount(removed.text());
        }
    }

    private int calculateTotalTokens() {
        return messages.stream()
            .mapToInt(msg -> tokenizer.estimateTokenCount(msg.text()))
            .sum();
    }
}

ChatMemoryProvider接口

多用户场景下的记忆提供者:

/**
 * 为不同用户提供独立的聊天记忆
 */
public interface ChatMemoryProvider {

    /**
     * 根据记忆ID获取或创建聊天记忆
     */
    ChatMemory get(Object memoryId);
}

实现示例

public class SimpleChatMemoryProvider implements ChatMemoryProvider {

    private final Map<Object, ChatMemory> memories = new ConcurrentHashMap<>();
    private final Supplier<ChatMemory> chatMemoryFactory;

    public SimpleChatMemoryProvider(Supplier<ChatMemory> factory) {
        this.chatMemoryFactory = factory;
    }

    @Override
    public ChatMemory get(Object memoryId) {
        return memories.computeIfAbsent(memoryId,
            id -> chatMemoryFactory.get());
    }
}

持久化存储接口

package dev.langchain4j.store.memory.chat;

import dev.langchain4j.data.message.ChatMessage;
import java.util.List;

/**
 * 聊天记忆持久化存储
 */
public interface ChatMemoryStore {

    /**
     * 获取指定记忆ID的消息
     */
    List<ChatMessage> getMessages(Object memoryId);

    /**
     * 更新消息
     */
    void updateMessages(Object memoryId, List<ChatMessage> messages);

    /**
     * 删除消息
     */
    void deleteMessages(Object memoryId);
}

📦 RAG系统设计

RAG(检索增强生成)系统通过检索相关文档来增强LLM的回答能力。

ContentRetriever接口

package dev.langchain4j.rag;

import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.content.Content;
import java.util.List;

/**
 * 内容检索器接口
 */
public interface ContentRetriever {

    /**
     * 根据查询检索相关内容
     * @param query 用户查询
     * @return 相关内容列表
     */
    List<Content> retrieve(Query query);
}

默认检索器实现

public class EmbeddingStoreContentRetriever implements ContentRetriever {

    private final EmbeddingStore<TextSegment> embeddingStore;
    private final EmbeddingModel embeddingModel;
    private final int maxResults;
    private final double minScore;

    @Override
    public List<Content> retrieve(Query query) {
        // 1. 将查询向量化
        Embedding queryEmbedding = embeddingModel.embed(query.text()).content();

        // 2. 在向量存储中搜索
        List<EmbeddingMatch<TextSegment>> matches = embeddingStore.findRelevant(
            queryEmbedding,
            maxResults,
            minScore
        );

        // 3. 转换为Content对象
        return matches.stream()
            .map(match -> Content.from(match.embedded().text()))
            .collect(Collectors.toList());
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private EmbeddingStore<TextSegment> embeddingStore;
        private EmbeddingModel embeddingModel;
        private int maxResults = 3;
        private double minScore = 0.7;

        // Builder方法...

        public EmbeddingStoreContentRetriever build() {
            return new EmbeddingStoreContentRetriever(this);
        }
    }
}

RetrievalAugmentor接口

增强检索管道,支持查询转换、内容重排等高级功能:

package dev.langchain4j.rag;

import dev.langchain4j.data.message.UserMessage;
import java.util.List;

/**
 * 检索增强器
 */
public interface RetrievalAugmentor {

    /**
     * 增强用户消息
     * @param userMessage 原始用户消息
     * @return 增强后的用户消息
     */
    UserMessage augment(UserMessage userMessage, Metadata metadata);
}

默认增强器实现

public class DefaultRetrievalAugmentor implements RetrievalAugmentor {

    private final QueryTransformer queryTransformer;
    private final QueryRouter queryRouter;
    private final ContentRetriever contentRetriever;
    private final ContentAggregator contentAggregator;
    private final ContentInjector contentInjector;

    @Override
    public UserMessage augment(UserMessage userMessage, Metadata metadata) {

        // 1. 转换查询(扩展、改写等)
        Query query = Query.from(userMessage.text());
        if (queryTransformer != null) {
            query = queryTransformer.transform(query);
        }

        // 2. 路由到合适的检索器(多知识库场景)
        ContentRetriever retriever = contentRetriever;
        if (queryRouter != null) {
            retriever = queryRouter.route(query);
        }

        // 3. 检索相关内容
        List<Content> contents = retriever.retrieve(query);

        // 4. 聚合内容(去重、合并等)
        if (contentAggregator != null) {
            contents = contentAggregator.aggregate(contents);
        }

        // 5. 注入到用户消息
        if (contentInjector != null) {
            return contentInjector.inject(contents, userMessage);
        }

        // 默认注入方式
        return injectDefault(contents, userMessage);
    }

    private UserMessage injectDefault(List<Content> contents,
                                       UserMessage userMessage) {
        String contextText = contents.stream()
            .map(Content::textSegment)
            .map(TextSegment::text)
            .collect(Collectors.joining("\n\n"));

        String augmentedText = String.format(
            "相关信息:\n%s\n\n用户问题:%s",
            contextText,
            userMessage.text()
        );

        return UserMessage.from(augmentedText);
    }
}

查询转换器

/**
 * 查询转换器(扩展、改写、分解等)
 */
public interface QueryTransformer {
    Query transform(Query query);
}

/**
 * 使用LLM扩展查询
 */
public class QueryExpander implements QueryTransformer {

    private final ChatLanguageModel chatModel;

    @Override
    public Query transform(Query query) {
        String prompt = String.format(
            "将以下查询扩展为更详细的搜索查询:\n%s",
            query.text()
        );

        String expandedQuery = chatModel.generate(prompt);
        return Query.from(expandedQuery);
    }
}

RAG完整管道示例

// 构建完整的RAG系统
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
    // 查询转换
    .queryTransformer(new QueryExpander(chatModel))
    // 内容检索
    .contentRetriever(
        EmbeddingStoreContentRetriever.builder()
            .embeddingStore(embeddingStore)
            .embeddingModel(embeddingModel)
            .maxResults(5)
            .minScore(0.7)
            .build()
    )
    // 内容重排(可选)
    .contentAggregator(new ReRankingContentAggregator(reRanker))
    .build();

// 集成到AI服务
Assistant assistant = AiServices.builder(Assistant.class)
    .chatLanguageModel(chatModel)
    .retrievalAugmentor(retrievalAugmentor)
    .build();

🎨 设计模式应用

LangChain4J框架中应用了多种经典设计模式。

1. Builder模式

几乎所有复杂对象都使用Builder模式构建:

// 链式调用,清晰易读
OpenAiChatModel model = OpenAiChatModel.builder()
    .apiKey(System.getenv("OPENAI_API_KEY"))
    .modelName("gpt-4")
    .temperature(0.7)
    .maxTokens(1000)
    .timeout(Duration.ofSeconds(30))
    .logRequests(true)
    .logResponses(true)
    .build();

优势

  • 参数众多时保持可读性
  • 支持可选参数
  • 提供参数验证
  • 支持不可变对象

2. Strategy模式

不同的检索、转换、注入策略可以互换:

// 策略接口
public interface QueryTransformer {
    Query transform(Query query);
}

// 不同策略实现
public class QueryExpander implements QueryTransformer { }
public class QueryCompressor implements QueryTransformer { }
public class QueryRewriter implements QueryTransformer { }

// 运行时选择策略
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
    .queryTransformer(new QueryExpander(chatModel))  // 可替换
    .build();

3. Template Method模式

定义算法骨架,子类实现具体步骤:

public abstract class BaseContentRetriever implements ContentRetriever {

    @Override
    public final List<Content> retrieve(Query query) {
        // 模板方法定义流程
        Query transformedQuery = preProcess(query);
        List<Content> contents = doRetrieve(transformedQuery);
        return postProcess(contents);
    }

    // 钩子方法,子类可选择性覆盖
    protected Query preProcess(Query query) {
        return query;
    }

    // 抽象方法,子类必须实现
    protected abstract List<Content> doRetrieve(Query query);

    protected List<Content> postProcess(List<Content> contents) {
        return contents;
    }
}

4. Chain of Responsibility模式

多个处理器链式处理请求:

/**
 * 内容过滤器链
 */
public interface ContentFilter {
    List<Content> filter(List<Content> contents);
}

public class ContentFilterChain {
    private final List<ContentFilter> filters;

    public List<Content> filter(List<Content> contents) {
        List<Content> result = contents;
        for (ContentFilter filter : filters) {
            result = filter.filter(result);
        }
        return result;
    }
}

// 使用
ContentFilterChain chain = new ContentFilterChain(List.of(
    new DuplicateRemovalFilter(),
    new LengthFilter(1000),
    new RelevanceFilter(0.8)
));

5. Factory模式

创建不同类型的对象:

/**
 * 记忆工厂
 */
public class ChatMemoryFactory {

    public static ChatMemory createWindowMemory(int maxMessages) {
        return MessageWindowChatMemory.builder()
            .maxMessages(maxMessages)
            .build();
    }

    public static ChatMemory createTokenWindowMemory(int maxTokens) {
        return TokenWindowChatMemory.builder()
            .maxTokens(maxTokens)
            .tokenizer(new OpenAiTokenizer())
            .build();
    }
}

6. Observer模式

流式响应使用观察者模式:

// 观察者接口
public interface StreamingResponseHandler<T> {
    void onNext(String token);
    void onComplete(Response<T> response);
    void onError(Throwable error);
}

// 使用
streamingChatModel.generate(messages, new StreamingResponseHandler<>() {
    @Override
    public void onNext(String token) {
        System.out.print(token);  // 实时输出
    }

    @Override
    public void onComplete(Response<AiMessage> response) {
        System.out.println("\n完成!");
    }

    @Override
    public void onError(Throwable error) {
        System.err.println("错误:" + error.getMessage());
    }
});

7. Adapter模式

适配不同提供商的API:

// 统一接口
public interface ChatLanguageModel {
    Response<AiMessage> generate(List<ChatMessage> messages);
}

// OpenAI适配器
public class OpenAiChatModel implements ChatLanguageModel {
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        // 转换为OpenAI格式
        // 调用OpenAI API
        // 转换回统一格式
    }
}

// Ollama适配器
public class OllamaChatModel implements ChatLanguageModel {
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        // 转换为Ollama格式
        // 调用Ollama API
        // 转换回统一格式
    }
}

🔌 扩展点和SPI

LangChain4J提供了丰富的扩展点,支持自定义实现。

自定义ChatLanguageModel

public class CustomChatModel implements ChatLanguageModel {

    private final String endpoint;
    private final HttpClient httpClient;

    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        // 1. 转换消息格式
        String requestBody = convertToCustomFormat(messages);

        // 2. 调用自定义API
        HttpResponse<String> response = httpClient.send(
            HttpRequest.newBuilder()
                .uri(URI.create(endpoint))
                .POST(HttpRequest.BodyPublishers.ofString(requestBody))
                .build(),
            HttpResponse.BodyHandlers.ofString()
        );

        // 3. 解析响应
        return parseResponse(response.body());
    }

    private String convertToCustomFormat(List<ChatMessage> messages) {
        // 实现自定义格式转换
        return "...";
    }

    private Response<AiMessage> parseResponse(String responseBody) {
        // 实现响应解析
        return Response.from(AiMessage.from("..."));
    }
}

自定义EmbeddingStore

public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {

    private final RedisClient redis;
    private final String indexName;

    @Override
    public String add(Embedding embedding, TextSegment embedded) {
        String id = UUID.randomUUID().toString();

        // 存储到Redis
        Map<String, String> hash = new HashMap<>();
        hash.put("text", embedded.text());
        hash.put("embedding", serializeEmbedding(embedding));

        redis.hset(indexName + ":" + id, hash);

        return id;
    }

    @Override
    public List<EmbeddingMatch<TextSegment>> findRelevant(
            Embedding referenceEmbedding,
            int maxResults,
            double minScore) {

        // 使用Redis向量搜索
        String query = String.format(
            "FT.SEARCH %s * => [KNN %d @embedding $vector]",
            indexName, maxResults
        );

        // 执行搜索并转换结果
        return executeSearchAndConvert(query, referenceEmbedding);
    }
}

自定义ContentRetriever

public class HybridContentRetriever implements ContentRetriever {

    private final EmbeddingStoreContentRetriever vectorRetriever;
    private final KeywordSearchRetriever keywordRetriever;
    private final double vectorWeight;
    private final double keywordWeight;

    @Override
    public List<Content> retrieve(Query query) {
        // 1. 向量检索
        List<Content> vectorResults = vectorRetriever.retrieve(query);

        // 2. 关键词检索
        List<Content> keywordResults = keywordRetriever.retrieve(query);

        // 3. 混合排序
        return mergeAndRank(vectorResults, keywordResults);
    }

    private List<Content> mergeAndRank(List<Content> vector,
                                        List<Content> keyword) {
        // 实现混合排序算法(如RRF)
        Map<Content, Double> scores = new HashMap<>();

        // 向量检索分数
        for (int i = 0; i < vector.size(); i++) {
            scores.merge(vector.get(i),
                vectorWeight / (i + 1), Double::sum);
        }

        // 关键词检索分数
        for (int i = 0; i < keyword.size(); i++) {
            scores.merge(keyword.get(i),
                keywordWeight / (i + 1), Double::sum);
        }

        // 按分数排序
        return scores.entrySet().stream()
            .sorted(Map.Entry.<Content, Double>comparingByValue().reversed())
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    }
}

自定义Tokenizer

public class CustomTokenizer implements Tokenizer {

    @Override
    public int estimateTokenCount(String text) {
        // 实现token计数逻辑
        // 简单方法:字符数 / 4(英文)
        // 或使用专门的tokenizer库
        return text.length() / 4;
    }

    @Override
    public List<Integer> encode(String text) {
        // 实现文本编码
        return List.of(/* token ids */);
    }

    @Override
    public String decode(List<Integer> tokens) {
        // 实现token解码
        return "...";
    }
}

Spring Boot自动配置扩展

@Configuration
@ConditionalOnClass(ChatLanguageModel.class)
@EnableConfigurationProperties(CustomModelProperties.class)
public class CustomModelAutoConfiguration {

    @Bean
    @ConditionalOnMissingBean
    @ConditionalOnProperty("custom.model.api-key")
    public ChatLanguageModel customChatModel(
            CustomModelProperties properties) {
        return CustomChatModel.builder()
            .apiKey(properties.getApiKey())
            .endpoint(properties.getEndpoint())
            .modelName(properties.getModelName())
            .build();
    }
}

配置类

@ConfigurationProperties(prefix = "custom.model")
public class CustomModelProperties {
    private String apiKey;
    private String endpoint;
    private String modelName;

    // getters and setters
}

💡 实战练习

练习1:实现自定义记忆策略

实现一个基于重要性的记忆管理器,保留重要消息:

/**
 * 任务:实现ImportanceBasedChatMemory
 * - 给每条消息评估重要性分数
 * - 当超过限制时,删除不重要的消息
 * - 系统消息和包含特定关键词的消息优先保留
 */
public class ImportanceBasedChatMemory implements ChatMemory {
    // 你的实现
}

练习2:实现查询路由器

实现根据查询内容路由到不同知识库的路由器:

/**
 * 任务:实现QueryRouter
 * - 分析查询内容
 * - 路由到最合适的ContentRetriever
 * - 支持多知识库并行检索
 */
public class SmartQueryRouter implements QueryRouter {
    // 你的实现
}

练习3:实现自定义Tool执行器

实现支持异步Tool执行的执行器:

/**
 * 任务:实现AsyncToolExecutor
 * - 支持异步执行Tool
 * - 支持Tool执行超时控制
 * - 支持Tool执行重试
 */
public class AsyncToolExecutor {
    // 你的实现
}

练习4:实现内容重排器

使用交叉编码器对检索结果重新排序:

/**
 * 任务:实现ReRankingContentAggregator
 * - 使用重排序模型对结果重新评分
 * - 返回最相关的top-k结果
 */
public class ReRankingContentAggregator implements ContentAggregator {
    // 你的实现
}

最后更新:2026-03-09 字数统计:5,500 字 预计阅读时间:45 分钟