深入理解LangChain4J的框架设计和核心机制,能够帮助我们更好地使用框架、排查问题,并在需要时进行定制化扩展。
时间:45分钟 | 难度:⭐⭐⭐⭐ | Week 4 Day 22
📋 学习目标
- 理解LangChain4J的整体架构和模块划分
- 掌握核心接口的设计思想和职责
- 深入理解AiServices的动态代理实现机制
- 了解Tool系统的注解处理和函数调用原理
- 掌握Memory系统的架构设计
- 理解RAG系统的完整管道设计
- 识别框架中应用的设计模式
- 学会通过SPI扩展框架功能
🏗️ 框架整体架构
LangChain4J采用模块化设计,将功能划分为多个独立的Maven模块:
核心模块组成
langchain4j-parent
├── langchain4j-core # 核心接口和抽象
│ ├── model interfaces # ChatLanguageModel, EmbeddingModel等
│ ├── data structures # ChatMessage, TextSegment等
│ ├── store interfaces # EmbeddingStore, ChatMemoryStore等
│ └── rag interfaces # ContentRetriever, QueryTransformer等
│
├── langchain4j # 高层抽象和组合功能
│ ├── AiServices # 动态代理服务
│ ├── chain builders # ConversationalChain, RetrievalChain等
│ ├── memory # 内存管理实现
│ └── rag # RAG增强实现
│
├── langchain4j-open-ai # OpenAI集成
├── langchain4j-azure-open-ai # Azure OpenAI集成
├── langchain4j-ollama # Ollama集成
├── langchain4j-pgvector # PostgreSQL向量存储
├── langchain4j-redis # Redis向量存储
├── langchain4j-chroma # Chroma向量数据库
└── langchain4j-spring-boot-starter # Spring Boot自动配置
架构分层
┌─────────────────────────────────────────────────────┐
│ Application Layer (Your Code) │
│ @RestController, @Service, Business Logic │
└────────────────────┬────────────────────────────────┘
│
┌────────────────────▼────────────────────────────────┐
│ High-Level API (langchain4j) │
│ AiServices, ConversationalChain, RetrievalChain │
└────────────────────┬────────────────────────────────┘
│
┌────────────────────▼────────────────────────────────┐
│ Core Abstractions (langchain4j-core) │
│ Interfaces: ChatLanguageModel, EmbeddingStore, etc. │
└────────────────────┬────────────────────────────────┘
│
┌────────────────────▼────────────────────────────────┐
│ Provider Implementations (langchain4j-xxx) │
│ OpenAI, Ollama, PGVector, Redis, etc. │
└─────────────────────────────────────────────────────┘
设计优势:
- 依赖倒置:应用代码依赖抽象接口,不依赖具体实现
- 可替换性:可以轻松切换不同的模型提供商或存储后端
- 独立演进:各个模块可以独立版本管理和发布
- 按需引入:只需引入实际使用的模块,减小依赖体积
🔑 核心接口设计
LangChain4J的核心接口设计遵循单一职责原则,每个接口都有明确的职责边界。
ChatLanguageModel接口
这是所有聊天模型的顶层接口:
package dev.langchain4j.model.chat;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.output.Response;
import java.util.List;
/**
* 聊天语言模型的核心接口
*/
public interface ChatLanguageModel {
/**
* 生成单次响应
* @param messages 对话消息列表
* @return 模型响应
*/
Response<AiMessage> generate(List<ChatMessage> messages);
/**
* 生成简单文本响应(便捷方法)
* @param userMessage 用户消息
* @return 响应文本
*/
default String generate(String userMessage) {
UserMessage message = UserMessage.from(userMessage);
Response<AiMessage> response = generate(List.of(message));
return response.content().text();
}
}
支持流式响应的扩展接口:
package dev.langchain4j.model.chat;
import dev.langchain4j.data.message.ChatMessage;
import java.util.List;
/**
* 支持流式输出的聊天模型
*/
public interface StreamingChatLanguageModel {
/**
* 流式生成响应
* @param messages 对话消息列表
* @param handler 流式响应处理器
*/
void generate(List<ChatMessage> messages,
StreamingResponseHandler<AiMessage> handler);
}
具体实现示例(OpenAI):
public class OpenAiChatModel implements ChatLanguageModel {
private final String apiKey;
private final String modelName;
private final Double temperature;
private final Integer maxTokens;
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
// 1. 转换消息格式
List<OpenAiMessage> openAiMessages = messages.stream()
.map(this::toOpenAiMessage)
.collect(Collectors.toList());
// 2. 构建请求
OpenAiChatCompletionRequest request = OpenAiChatCompletionRequest.builder()
.model(modelName)
.messages(openAiMessages)
.temperature(temperature)
.maxTokens(maxTokens)
.build();
// 3. 调用API
OpenAiChatCompletionResponse response = openAiClient.chatCompletion(request);
// 4. 转换响应
return Response.from(
AiMessage.from(response.choices().get(0).message().content()),
response.tokenUsage(),
response.finishReason()
);
}
}
EmbeddingModel接口
用于文本向量化的核心接口:
package dev.langchain4j.model.embedding;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import java.util.List;
/**
* 文本嵌入模型接口
*/
public interface EmbeddingModel {
/**
* 将单个文本转换为向量
*/
Response<Embedding> embed(String text);
/**
* 将文本段转换为向量
*/
Response<Embedding> embed(TextSegment textSegment);
/**
* 批量向量化(提高效率)
*/
Response<List<Embedding>> embedAll(List<TextSegment> textSegments);
/**
* 返回向量维度
*/
int dimension();
}
实现示例:
public class OpenAiEmbeddingModel implements EmbeddingModel {
private final String apiKey;
private final String modelName; // text-embedding-ada-002
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
List<String> texts = textSegments.stream()
.map(TextSegment::text)
.collect(Collectors.toList());
OpenAiEmbeddingRequest request = OpenAiEmbeddingRequest.builder()
.model(modelName)
.input(texts)
.build();
OpenAiEmbeddingResponse response = openAiClient.embedding(request);
List<Embedding> embeddings = response.data().stream()
.map(data -> Embedding.from(data.embedding()))
.collect(Collectors.toList());
return Response.from(embeddings, response.usage());
}
@Override
public int dimension() {
return 1536; // ada-002的维度
}
}
EmbeddingStore接口
向量存储的抽象接口:
package dev.langchain4j.store.embedding;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import java.util.List;
/**
* 向量存储接口
*/
public interface EmbeddingStore<Embedded> {
/**
* 添加单个向量
* @return 向量ID
*/
String add(Embedding embedding);
/**
* 添加向量及其关联的嵌入对象
*/
String add(Embedding embedding, Embedded embedded);
/**
* 批量添加
*/
List<String> addAll(List<Embedding> embeddings);
/**
* 相似度搜索
* @param referenceEmbedding 查询向量
* @param maxResults 最大结果数
* @param minScore 最小相似度分数
* @return 搜索结果列表
*/
List<EmbeddingMatch<Embedded>> findRelevant(
Embedding referenceEmbedding,
int maxResults,
double minScore
);
}
搜索结果封装:
public class EmbeddingMatch<Embedded> {
private final double score; // 相似度分数
private final String embeddingId; // 向量ID
private final Embedding embedding; // 向量本身
private final Embedded embedded; // 关联的对象(如TextSegment)
// 构造器、getter等
}
🪄 AiServices动态代理机制
AiServices是LangChain4J最强大的功能之一,它使用Java动态代理将接口方法转换为LLM调用。
使用示例回顾
interface Assistant {
@SystemMessage("你是一个友好的助手")
String chat(String userMessage);
}
Assistant assistant = AiServices.create(Assistant.class, chatModel);
String response = assistant.chat("你好");
动态代理实现原理
package dev.langchain4j.service;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
public class AiServices {
/**
* 创建AI服务代理
*/
public static <T> T create(Class<T> aiServiceClass,
ChatLanguageModel chatModel) {
return (T) Proxy.newProxyInstance(
aiServiceClass.getClassLoader(),
new Class[]{aiServiceClass},
new AiServiceInvocationHandler(chatModel)
);
}
/**
* 调用处理器
*/
private static class AiServiceInvocationHandler implements InvocationHandler {
private final ChatLanguageModel chatModel;
private final ChatMemory chatMemory;
private final List<ToolSpecification> tools;
@Override
public Object invoke(Object proxy, Method method, Object[] args)
throws Throwable {
// 1. 解析方法注解
ServiceMethodInfo methodInfo = parseMethod(method);
// 2. 构建消息列表
List<ChatMessage> messages = buildMessages(methodInfo, args);
// 3. 如果有tool,添加tool规范
if (!tools.isEmpty()) {
messages = enrichWithTools(messages, tools);
}
// 4. 调用模型
Response<AiMessage> response = chatModel.generate(messages);
// 5. 处理tool调用
if (response.content().hasToolExecutionRequests()) {
response = handleToolExecution(response);
}
// 6. 更新记忆
if (chatMemory != null) {
chatMemory.add(messages);
chatMemory.add(response.content());
}
// 7. 转换返回值
return convertReturnValue(response, method.getReturnType());
}
}
}
方法注解解析
/**
* 方法元信息
*/
class ServiceMethodInfo {
String systemMessage;
String userMessageTemplate;
Map<String, Integer> variablePositions;
static ServiceMethodInfo parse(Method method) {
ServiceMethodInfo info = new ServiceMethodInfo();
// 解析@SystemMessage
if (method.isAnnotationPresent(SystemMessage.class)) {
SystemMessage annotation = method.getAnnotation(SystemMessage.class);
info.systemMessage = annotation.value();
}
// 解析@UserMessage
if (method.isAnnotationPresent(UserMessage.class)) {
UserMessage annotation = method.getAnnotation(UserMessage.class);
info.userMessageTemplate = annotation.value();
// 解析模板变量
info.variablePositions = parseTemplateVariables(
annotation.value(),
method.getParameters()
);
}
return info;
}
}
消息构建流程
class MessageBuilder {
List<ChatMessage> buildMessages(ServiceMethodInfo methodInfo, Object[] args) {
List<ChatMessage> messages = new ArrayList<>();
// 1. 添加系统消息
if (methodInfo.systemMessage != null) {
messages.add(SystemMessage.from(methodInfo.systemMessage));
}
// 2. 从记忆中恢复历史消息
if (chatMemory != null) {
messages.addAll(chatMemory.messages());
}
// 3. 构建用户消息
String userMessage = buildUserMessage(methodInfo, args);
messages.add(UserMessage.from(userMessage));
return messages;
}
String buildUserMessage(ServiceMethodInfo methodInfo, Object[] args) {
if (methodInfo.userMessageTemplate != null) {
// 使用模板替换变量
return replaceVariables(
methodInfo.userMessageTemplate,
methodInfo.variablePositions,
args
);
} else {
// 直接使用第一个参数
return String.valueOf(args[0]);
}
}
}
Builder模式配置
public class AiServices<T> {
public static <T> AiServicesBuilder<T> builder(Class<T> aiServiceClass) {
return new AiServicesBuilder<>(aiServiceClass);
}
public static class AiServicesBuilder<T> {
private Class<T> aiServiceClass;
private ChatLanguageModel chatLanguageModel;
private StreamingChatLanguageModel streamingChatLanguageModel;
private ChatMemory chatMemory;
private ChatMemoryProvider chatMemoryProvider;
private List<Object> tools = new ArrayList<>();
private ContentRetriever contentRetriever;
public AiServicesBuilder<T> chatLanguageModel(
ChatLanguageModel model) {
this.chatLanguageModel = model;
return this;
}
public AiServicesBuilder<T> chatMemory(ChatMemory chatMemory) {
this.chatMemory = chatMemory;
return this;
}
public AiServicesBuilder<T> tools(Object... tools) {
this.tools.addAll(Arrays.asList(tools));
return this;
}
public T build() {
// 验证必需配置
validateConfiguration();
// 解析工具
List<ToolSpecification> toolSpecs = parseTools(tools);
// 创建代理
return createProxy(
aiServiceClass,
chatLanguageModel,
chatMemory,
toolSpecs
);
}
}
}
🔧 Tool系统实现原理
Tool系统允许LLM调用Java方法来获取外部数据或执行操作。
@Tool注解处理
package dev.langchain4j.agent.tool;
import java.lang.annotation.*;
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Tool {
/**
* 工具名称(发送给LLM)
*/
String name() default "";
/**
* 工具描述(帮助LLM理解何时使用)
*/
String value() default "";
}
参数注解:
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface P {
/**
* 参数描述
*/
String value();
}
工具定义示例
public class WeatherTools {
@Tool("获取指定城市的当前天气信息")
public String getCurrentWeather(
@P("城市名称,如'北京'或'上海'") String city
) {
// 实际应该调用天气API
return String.format("城市:%s, 温度:22°C, 天气:晴", city);
}
@Tool("获取未来几天的天气预报")
public String getWeatherForecast(
@P("城市名称") String city,
@P("预报天数,1-7之间") int days
) {
return String.format("城市:%s, %d天预报:...", city, days);
}
}
ToolSpecification生成
/**
* 工具规范(发送给LLM的工具描述)
*/
public class ToolSpecification {
private String name;
private String description;
private ToolParameters parameters;
/**
* 从Java方法解析工具规范
*/
public static ToolSpecification from(Method method) {
Tool toolAnnotation = method.getAnnotation(Tool.class);
String name = toolAnnotation.name().isEmpty()
? method.getName()
: toolAnnotation.name();
String description = toolAnnotation.value();
// 解析参数
ToolParameters parameters = parseParameters(method);
return new ToolSpecification(name, description, parameters);
}
private static ToolParameters parseParameters(Method method) {
Map<String, ToolParameter> properties = new HashMap<>();
List<String> required = new ArrayList<>();
Parameter[] parameters = method.getParameters();
for (int i = 0; i < parameters.length; i++) {
Parameter param = parameters[i];
P annotation = param.getAnnotation(P.class);
String paramName = param.getName();
String description = annotation != null ? annotation.value() : "";
String type = mapJavaTypeToJsonType(param.getType());
properties.put(paramName, new ToolParameter(type, description));
required.add(paramName);
}
return new ToolParameters("object", properties, required);
}
}
LLM眼中的工具
转换为JSON Schema格式发送给LLM:
{
"type": "function",
"function": {
"name": "getCurrentWeather",
"description": "获取指定城市的当前天气信息",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "城市名称,如'北京'或'上海'"
}
},
"required": ["city"]
}
}
}
工具执行流程
class ToolExecutor {
private final Map<String, ToolExecutionRequest> toolRegistry;
/**
* 执行工具调用
*/
String executeTool(ToolExecutionRequest request) {
// 1. 查找工具方法
ToolExecutionRequest toolMethod = toolRegistry.get(request.name());
if (toolMethod == null) {
return "Error: Tool not found: " + request.name();
}
try {
// 2. 解析参数
Object[] args = parseArguments(
toolMethod.method(),
request.arguments()
);
// 3. 执行方法
Object result = toolMethod.method().invoke(
toolMethod.instance(),
args
);
// 4. 返回结果
return String.valueOf(result);
} catch (Exception e) {
return "Error executing tool: " + e.getMessage();
}
}
/**
* 处理包含工具调用的响应
*/
Response<AiMessage> handleToolExecution(Response<AiMessage> response) {
List<ChatMessage> messages = new ArrayList<>();
messages.add(response.content());
// 执行所有工具调用
for (ToolExecutionRequest request :
response.content().toolExecutionRequests()) {
String result = executeTool(request);
messages.add(ToolExecutionResultMessage.from(
request.id(),
request.name(),
result
));
}
// 再次调用LLM处理工具结果
return chatLanguageModel.generate(messages);
}
}
🧠 Memory系统架构
Memory系统负责管理对话历史,是实现上下文感知对话的关键。
ChatMemory接口
package dev.langchain4j.memory;
import dev.langchain4j.data.message.ChatMessage;
import java.util.List;
/**
* 聊天记忆接口
*/
public interface ChatMemory {
/**
* 获取聊天记忆的唯一标识
*/
Object id();
/**
* 添加消息到记忆
*/
void add(ChatMessage message);
/**
* 批量添加消息
*/
default void add(List<ChatMessage> messages) {
messages.forEach(this::add);
}
/**
* 获取所有消息
*/
List<ChatMessage> messages();
/**
* 清空记忆
*/
void clear();
}
MessageWindowChatMemory实现
基于滑动窗口的记忆实现,保留最近N条消息:
public class MessageWindowChatMemory implements ChatMemory {
private final Object id;
private final int maxMessages;
private final LinkedList<ChatMessage> messages;
private final ChatMemoryStore store; // 可选的持久化存储
public MessageWindowChatMemory(int maxMessages) {
this.id = UUID.randomUUID();
this.maxMessages = maxMessages;
this.messages = new LinkedList<>();
this.store = null;
}
@Override
public void add(ChatMessage message) {
messages.add(message);
// 超过最大消息数,移除最旧的
while (messages.size() > maxMessages) {
messages.removeFirst();
}
// 持久化到存储
if (store != null) {
store.updateMessages(id, messages);
}
}
@Override
public List<ChatMessage> messages() {
return new ArrayList<>(messages);
}
@Override
public void clear() {
messages.clear();
if (store != null) {
store.deleteMessages(id);
}
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private int maxMessages = 10;
private Object id;
private ChatMemoryStore store;
public Builder maxMessages(int maxMessages) {
this.maxMessages = maxMessages;
return this;
}
public Builder id(Object id) {
this.id = id;
return this;
}
public Builder chatMemoryStore(ChatMemoryStore store) {
this.store = store;
return this;
}
public MessageWindowChatMemory build() {
return new MessageWindowChatMemory(this);
}
}
}
TokenWindowChatMemory实现
基于token数量的记忆实现,确保不超过模型上下文限制:
public class TokenWindowChatMemory implements ChatMemory {
private final Object id;
private final int maxTokens;
private final Tokenizer tokenizer;
private final LinkedList<ChatMessage> messages;
@Override
public void add(ChatMessage message) {
messages.add(message);
// 计算当前token总数
int totalTokens = calculateTotalTokens();
// 移除旧消息直到token数在限制内
while (totalTokens > maxTokens && messages.size() > 1) {
ChatMessage removed = messages.removeFirst();
totalTokens -= tokenizer.estimateTokenCount(removed.text());
}
}
private int calculateTotalTokens() {
return messages.stream()
.mapToInt(msg -> tokenizer.estimateTokenCount(msg.text()))
.sum();
}
}
ChatMemoryProvider接口
多用户场景下的记忆提供者:
/**
* 为不同用户提供独立的聊天记忆
*/
public interface ChatMemoryProvider {
/**
* 根据记忆ID获取或创建聊天记忆
*/
ChatMemory get(Object memoryId);
}
实现示例:
public class SimpleChatMemoryProvider implements ChatMemoryProvider {
private final Map<Object, ChatMemory> memories = new ConcurrentHashMap<>();
private final Supplier<ChatMemory> chatMemoryFactory;
public SimpleChatMemoryProvider(Supplier<ChatMemory> factory) {
this.chatMemoryFactory = factory;
}
@Override
public ChatMemory get(Object memoryId) {
return memories.computeIfAbsent(memoryId,
id -> chatMemoryFactory.get());
}
}
持久化存储接口
package dev.langchain4j.store.memory.chat;
import dev.langchain4j.data.message.ChatMessage;
import java.util.List;
/**
* 聊天记忆持久化存储
*/
public interface ChatMemoryStore {
/**
* 获取指定记忆ID的消息
*/
List<ChatMessage> getMessages(Object memoryId);
/**
* 更新消息
*/
void updateMessages(Object memoryId, List<ChatMessage> messages);
/**
* 删除消息
*/
void deleteMessages(Object memoryId);
}
📦 RAG系统设计
RAG(检索增强生成)系统通过检索相关文档来增强LLM的回答能力。
ContentRetriever接口
package dev.langchain4j.rag;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.content.Content;
import java.util.List;
/**
* 内容检索器接口
*/
public interface ContentRetriever {
/**
* 根据查询检索相关内容
* @param query 用户查询
* @return 相关内容列表
*/
List<Content> retrieve(Query query);
}
默认检索器实现
public class EmbeddingStoreContentRetriever implements ContentRetriever {
private final EmbeddingStore<TextSegment> embeddingStore;
private final EmbeddingModel embeddingModel;
private final int maxResults;
private final double minScore;
@Override
public List<Content> retrieve(Query query) {
// 1. 将查询向量化
Embedding queryEmbedding = embeddingModel.embed(query.text()).content();
// 2. 在向量存储中搜索
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.findRelevant(
queryEmbedding,
maxResults,
minScore
);
// 3. 转换为Content对象
return matches.stream()
.map(match -> Content.from(match.embedded().text()))
.collect(Collectors.toList());
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private EmbeddingStore<TextSegment> embeddingStore;
private EmbeddingModel embeddingModel;
private int maxResults = 3;
private double minScore = 0.7;
// Builder方法...
public EmbeddingStoreContentRetriever build() {
return new EmbeddingStoreContentRetriever(this);
}
}
}
RetrievalAugmentor接口
增强检索管道,支持查询转换、内容重排等高级功能:
package dev.langchain4j.rag;
import dev.langchain4j.data.message.UserMessage;
import java.util.List;
/**
* 检索增强器
*/
public interface RetrievalAugmentor {
/**
* 增强用户消息
* @param userMessage 原始用户消息
* @return 增强后的用户消息
*/
UserMessage augment(UserMessage userMessage, Metadata metadata);
}
默认增强器实现
public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
private final QueryTransformer queryTransformer;
private final QueryRouter queryRouter;
private final ContentRetriever contentRetriever;
private final ContentAggregator contentAggregator;
private final ContentInjector contentInjector;
@Override
public UserMessage augment(UserMessage userMessage, Metadata metadata) {
// 1. 转换查询(扩展、改写等)
Query query = Query.from(userMessage.text());
if (queryTransformer != null) {
query = queryTransformer.transform(query);
}
// 2. 路由到合适的检索器(多知识库场景)
ContentRetriever retriever = contentRetriever;
if (queryRouter != null) {
retriever = queryRouter.route(query);
}
// 3. 检索相关内容
List<Content> contents = retriever.retrieve(query);
// 4. 聚合内容(去重、合并等)
if (contentAggregator != null) {
contents = contentAggregator.aggregate(contents);
}
// 5. 注入到用户消息
if (contentInjector != null) {
return contentInjector.inject(contents, userMessage);
}
// 默认注入方式
return injectDefault(contents, userMessage);
}
private UserMessage injectDefault(List<Content> contents,
UserMessage userMessage) {
String contextText = contents.stream()
.map(Content::textSegment)
.map(TextSegment::text)
.collect(Collectors.joining("\n\n"));
String augmentedText = String.format(
"相关信息:\n%s\n\n用户问题:%s",
contextText,
userMessage.text()
);
return UserMessage.from(augmentedText);
}
}
查询转换器
/**
* 查询转换器(扩展、改写、分解等)
*/
public interface QueryTransformer {
Query transform(Query query);
}
/**
* 使用LLM扩展查询
*/
public class QueryExpander implements QueryTransformer {
private final ChatLanguageModel chatModel;
@Override
public Query transform(Query query) {
String prompt = String.format(
"将以下查询扩展为更详细的搜索查询:\n%s",
query.text()
);
String expandedQuery = chatModel.generate(prompt);
return Query.from(expandedQuery);
}
}
RAG完整管道示例
// 构建完整的RAG系统
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
// 查询转换
.queryTransformer(new QueryExpander(chatModel))
// 内容检索
.contentRetriever(
EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(5)
.minScore(0.7)
.build()
)
// 内容重排(可选)
.contentAggregator(new ReRankingContentAggregator(reRanker))
.build();
// 集成到AI服务
Assistant assistant = AiServices.builder(Assistant.class)
.chatLanguageModel(chatModel)
.retrievalAugmentor(retrievalAugmentor)
.build();
🎨 设计模式应用
LangChain4J框架中应用了多种经典设计模式。
1. Builder模式
几乎所有复杂对象都使用Builder模式构建:
// 链式调用,清晰易读
OpenAiChatModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("gpt-4")
.temperature(0.7)
.maxTokens(1000)
.timeout(Duration.ofSeconds(30))
.logRequests(true)
.logResponses(true)
.build();
优势:
- 参数众多时保持可读性
- 支持可选参数
- 提供参数验证
- 支持不可变对象
2. Strategy模式
不同的检索、转换、注入策略可以互换:
// 策略接口
public interface QueryTransformer {
Query transform(Query query);
}
// 不同策略实现
public class QueryExpander implements QueryTransformer { }
public class QueryCompressor implements QueryTransformer { }
public class QueryRewriter implements QueryTransformer { }
// 运行时选择策略
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
.queryTransformer(new QueryExpander(chatModel)) // 可替换
.build();
3. Template Method模式
定义算法骨架,子类实现具体步骤:
public abstract class BaseContentRetriever implements ContentRetriever {
@Override
public final List<Content> retrieve(Query query) {
// 模板方法定义流程
Query transformedQuery = preProcess(query);
List<Content> contents = doRetrieve(transformedQuery);
return postProcess(contents);
}
// 钩子方法,子类可选择性覆盖
protected Query preProcess(Query query) {
return query;
}
// 抽象方法,子类必须实现
protected abstract List<Content> doRetrieve(Query query);
protected List<Content> postProcess(List<Content> contents) {
return contents;
}
}
4. Chain of Responsibility模式
多个处理器链式处理请求:
/**
* 内容过滤器链
*/
public interface ContentFilter {
List<Content> filter(List<Content> contents);
}
public class ContentFilterChain {
private final List<ContentFilter> filters;
public List<Content> filter(List<Content> contents) {
List<Content> result = contents;
for (ContentFilter filter : filters) {
result = filter.filter(result);
}
return result;
}
}
// 使用
ContentFilterChain chain = new ContentFilterChain(List.of(
new DuplicateRemovalFilter(),
new LengthFilter(1000),
new RelevanceFilter(0.8)
));
5. Factory模式
创建不同类型的对象:
/**
* 记忆工厂
*/
public class ChatMemoryFactory {
public static ChatMemory createWindowMemory(int maxMessages) {
return MessageWindowChatMemory.builder()
.maxMessages(maxMessages)
.build();
}
public static ChatMemory createTokenWindowMemory(int maxTokens) {
return TokenWindowChatMemory.builder()
.maxTokens(maxTokens)
.tokenizer(new OpenAiTokenizer())
.build();
}
}
6. Observer模式
流式响应使用观察者模式:
// 观察者接口
public interface StreamingResponseHandler<T> {
void onNext(String token);
void onComplete(Response<T> response);
void onError(Throwable error);
}
// 使用
streamingChatModel.generate(messages, new StreamingResponseHandler<>() {
@Override
public void onNext(String token) {
System.out.print(token); // 实时输出
}
@Override
public void onComplete(Response<AiMessage> response) {
System.out.println("\n完成!");
}
@Override
public void onError(Throwable error) {
System.err.println("错误:" + error.getMessage());
}
});
7. Adapter模式
适配不同提供商的API:
// 统一接口
public interface ChatLanguageModel {
Response<AiMessage> generate(List<ChatMessage> messages);
}
// OpenAI适配器
public class OpenAiChatModel implements ChatLanguageModel {
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
// 转换为OpenAI格式
// 调用OpenAI API
// 转换回统一格式
}
}
// Ollama适配器
public class OllamaChatModel implements ChatLanguageModel {
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
// 转换为Ollama格式
// 调用Ollama API
// 转换回统一格式
}
}
🔌 扩展点和SPI
LangChain4J提供了丰富的扩展点,支持自定义实现。
自定义ChatLanguageModel
public class CustomChatModel implements ChatLanguageModel {
private final String endpoint;
private final HttpClient httpClient;
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
// 1. 转换消息格式
String requestBody = convertToCustomFormat(messages);
// 2. 调用自定义API
HttpResponse<String> response = httpClient.send(
HttpRequest.newBuilder()
.uri(URI.create(endpoint))
.POST(HttpRequest.BodyPublishers.ofString(requestBody))
.build(),
HttpResponse.BodyHandlers.ofString()
);
// 3. 解析响应
return parseResponse(response.body());
}
private String convertToCustomFormat(List<ChatMessage> messages) {
// 实现自定义格式转换
return "...";
}
private Response<AiMessage> parseResponse(String responseBody) {
// 实现响应解析
return Response.from(AiMessage.from("..."));
}
}
自定义EmbeddingStore
public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {
private final RedisClient redis;
private final String indexName;
@Override
public String add(Embedding embedding, TextSegment embedded) {
String id = UUID.randomUUID().toString();
// 存储到Redis
Map<String, String> hash = new HashMap<>();
hash.put("text", embedded.text());
hash.put("embedding", serializeEmbedding(embedding));
redis.hset(indexName + ":" + id, hash);
return id;
}
@Override
public List<EmbeddingMatch<TextSegment>> findRelevant(
Embedding referenceEmbedding,
int maxResults,
double minScore) {
// 使用Redis向量搜索
String query = String.format(
"FT.SEARCH %s * => [KNN %d @embedding $vector]",
indexName, maxResults
);
// 执行搜索并转换结果
return executeSearchAndConvert(query, referenceEmbedding);
}
}
自定义ContentRetriever
public class HybridContentRetriever implements ContentRetriever {
private final EmbeddingStoreContentRetriever vectorRetriever;
private final KeywordSearchRetriever keywordRetriever;
private final double vectorWeight;
private final double keywordWeight;
@Override
public List<Content> retrieve(Query query) {
// 1. 向量检索
List<Content> vectorResults = vectorRetriever.retrieve(query);
// 2. 关键词检索
List<Content> keywordResults = keywordRetriever.retrieve(query);
// 3. 混合排序
return mergeAndRank(vectorResults, keywordResults);
}
private List<Content> mergeAndRank(List<Content> vector,
List<Content> keyword) {
// 实现混合排序算法(如RRF)
Map<Content, Double> scores = new HashMap<>();
// 向量检索分数
for (int i = 0; i < vector.size(); i++) {
scores.merge(vector.get(i),
vectorWeight / (i + 1), Double::sum);
}
// 关键词检索分数
for (int i = 0; i < keyword.size(); i++) {
scores.merge(keyword.get(i),
keywordWeight / (i + 1), Double::sum);
}
// 按分数排序
return scores.entrySet().stream()
.sorted(Map.Entry.<Content, Double>comparingByValue().reversed())
.map(Map.Entry::getKey)
.collect(Collectors.toList());
}
}
自定义Tokenizer
public class CustomTokenizer implements Tokenizer {
@Override
public int estimateTokenCount(String text) {
// 实现token计数逻辑
// 简单方法:字符数 / 4(英文)
// 或使用专门的tokenizer库
return text.length() / 4;
}
@Override
public List<Integer> encode(String text) {
// 实现文本编码
return List.of(/* token ids */);
}
@Override
public String decode(List<Integer> tokens) {
// 实现token解码
return "...";
}
}
Spring Boot自动配置扩展
@Configuration
@ConditionalOnClass(ChatLanguageModel.class)
@EnableConfigurationProperties(CustomModelProperties.class)
public class CustomModelAutoConfiguration {
@Bean
@ConditionalOnMissingBean
@ConditionalOnProperty("custom.model.api-key")
public ChatLanguageModel customChatModel(
CustomModelProperties properties) {
return CustomChatModel.builder()
.apiKey(properties.getApiKey())
.endpoint(properties.getEndpoint())
.modelName(properties.getModelName())
.build();
}
}
配置类
@ConfigurationProperties(prefix = "custom.model")
public class CustomModelProperties {
private String apiKey;
private String endpoint;
private String modelName;
// getters and setters
}
💡 实战练习
练习1:实现自定义记忆策略
实现一个基于重要性的记忆管理器,保留重要消息:
/**
* 任务:实现ImportanceBasedChatMemory
* - 给每条消息评估重要性分数
* - 当超过限制时,删除不重要的消息
* - 系统消息和包含特定关键词的消息优先保留
*/
public class ImportanceBasedChatMemory implements ChatMemory {
// 你的实现
}
练习2:实现查询路由器
实现根据查询内容路由到不同知识库的路由器:
/**
* 任务:实现QueryRouter
* - 分析查询内容
* - 路由到最合适的ContentRetriever
* - 支持多知识库并行检索
*/
public class SmartQueryRouter implements QueryRouter {
// 你的实现
}
练习3:实现自定义Tool执行器
实现支持异步Tool执行的执行器:
/**
* 任务:实现AsyncToolExecutor
* - 支持异步执行Tool
* - 支持Tool执行超时控制
* - 支持Tool执行重试
*/
public class AsyncToolExecutor {
// 你的实现
}
练习4:实现内容重排器
使用交叉编码器对检索结果重新排序:
/**
* 任务:实现ReRankingContentAggregator
* - 使用重排序模型对结果重新评分
* - 返回最相关的top-k结果
*/
public class ReRankingContentAggregator implements ContentAggregator {
// 你的实现
}
最后更新:2026-03-09 字数统计:5,500 字 预计阅读时间:45 分钟