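# SpringAI (GA): RAG Quick Start + Modular Walkthrough
Tutorial notes
Note: this tutorial is based on the GA release published on May 20, 2025, and provides:
- Quick-start tutorials for the core modules
- Source-level walkthroughs of the core modules
- Quick-start tutorials and source-level walkthroughs for the Spring AI Alibaba enhancements
Versions: JDK 21 + Spring Boot 3.4.5 + Spring AI 1.0.0 + Spring AI Alibaba 1.0.0.2
The remaining chapters will be published over time. This part of Chapter 6 (improving Q&A quality with RAG) covers the RAG quick start and a source-level walkthrough of modular RAG.
Source code: github.com/GTyingzi/sp…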
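RAG Quick Start
> [!TIP]
> RAG (Retrieval-Augmented Generation) retrieves relevant information from an external knowledge base and passes it to a large language model (LLM) as part of the prompt, which improves the model on knowledge-intensive tasks. For a vector database quick start, see Chapter 5: Vector Databases.

The following examples use an in-memory vector store to demonstrate typical RAG scenarios such as Pre-Retrieval, Retrieval, and Generation.
Full code: github.com/GTyingzi/sp…, module rag/rag-simple
For ingesting different data sources, see "RAG ETL Pipeline Quick Start".
For the modular RAG source walkthrough, see "RAG Modular Source Code".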
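The core RAG loop can be reduced to a few lines of plain Java. The sketch below is a minimal illustration, assuming the embeddings are already computed (in the project, `EmbeddingModel` and `SimpleVectorStore` take care of that): rank stored chunks by cosine similarity to the query vector, keep the top k, and paste them into the prompt ahead of the question. `Chunk`, `NaiveRag` and the prompt wording are hypothetical names for illustration, not Spring AI APIs.

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical stand-in for a stored chunk: its text plus a precomputed embedding.
record Chunk(String text, double[] vector) {}

final class NaiveRag {

    // Cosine similarity between two vectors of equal length (0 if either is all zeros).
    static double cosine(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return (na == 0 || nb == 0) ? 0 : dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    // Retrieval: rank every chunk by similarity to the query vector and keep the best k.
    static List<Chunk> topK(double[] queryVector, List<Chunk> store, int k) {
        return store.stream()
                .sorted(Comparator.comparingDouble((Chunk c) -> cosine(queryVector, c.vector())).reversed())
                .limit(k)
                .toList();
    }

    // Augmentation: put the retrieved text into the prompt before the user's question.
    static String buildPrompt(String question, List<Chunk> context) {
        String joined = context.stream().map(Chunk::text).collect(Collectors.joining("\n"));
        return "Context information is below.\n---\n" + joined
                + "\n---\nAnswer using only the context above.\nQuestion: " + question;
    }
}
```

In the Spring AI examples that follow, `VectorStoreDocumentRetriever` roughly plays the role of `topK` and `ContextualQueryAugmenter` the role of `buildPrompt`.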
pom.xml
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-autoconfigure-model-openai</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-rag</artifactId>
</dependency>
</dependencies>
application.yml
server:
  port: 8080
spring:
  application:
    name: rag-simple
  ai:
    openai:
      api-key: ${DASHSCOPE_API_KEY}
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      chat:
        options:
          model: qwen-max
      embedding:
        options:
          model: text-embedding-v1
A simple before/after comparison of RAG
RagSimpleController
package com.spring.ai.tutorial.rag.controller;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@RestController
@RequestMapping("/rag/simple")
public class RagSimpleController {
private static final Logger logger = LoggerFactory.getLogger(RagSimpleController.class);
private final SimpleVectorStore simpleVectorStore;
private final ChatClient chatClient;
public RagSimpleController(EmbeddingModel embeddingModel, ChatClient.Builder builder) {
this.simpleVectorStore = SimpleVectorStore
.builder(embeddingModel).build();
this.chatClient = builder.build();
}
@GetMapping("/add")
public void add() {
logger.info("start add data");
HashMap<String, Object> map = new HashMap<>();
map.put("year", 2025);
map.put("name", "yingzi");
List<Document> documents = List.of(
new Document("你的姓名是影子,湖南邵阳人,25年硕士毕业于北京科技大学,曾先后在百度、理想、快手实习,曾发表过一篇自然语言处理的sci,现在是一名AI研发工程师"),
new Document("你的姓名是影子,专业领域包含的数学、前后端、大数据、自然语言处理", Map.of("year", 2024)),
new Document("你姓名是影子,爱好是发呆、思考、运动", map));
simpleVectorStore.add(documents);
}
@GetMapping("/chat")
public String chat(@RequestParam(value = "query", defaultValue = "你好,请告诉我影子这个人的身份信息") String query) {
logger.info("start chat");
return chatClient.prompt(query).call().content();
}
@GetMapping("/chat-rag-advisor")
public String chatRagAdvisor(@RequestParam(value = "query", defaultValue = "你好,请告诉我影子这个人的身份信息") String query) {
logger.info("start chat with rag-advisor");
RetrievalAugmentationAdvisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
.documentRetriever(VectorStoreDocumentRetriever.builder()
.vectorStore(simpleVectorStore)
.build())
.build();
return chatClient.prompt(query)
.advisors(retrievalAugmentationAdvisor)
.call().content();
}
}
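Result
Asked directly, the model does not know who "影子" (Yingzi) is.
With RAG enabled, the model answers from the stored documents about "影子".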
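Modular RAG example
RAG can be assembled from a set of modular components (see "RAG Modular Source Code" below); this structured workflow helps keep the quality of the AI model's output high.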
DocumentSelectFirst
package com.spring.ai.tutorial.rag.service;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
import java.util.Collections;
import java.util.List;
public class DocumentSelectFirst implements DocumentPostProcessor {
@Override
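public List<Document> process(Query query, List<Document> documents) {
// Retrieval may return nothing; documents.get(0) would then throw IndexOutOfBoundsException
if (documents.isEmpty()) {
return List.of();
}
return Collections.singletonList(documents.get(0));
}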
}
Implements the DocumentPostProcessor interface and keeps only the first of the retrieved documents.
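DocumentSelectFirst is the simplest possible post-processor. A slightly more practical one might drop duplicate texts and cap how many documents reach the model. The plain-Java sketch below shows that logic; `Doc`, `PostProcessor` and `DedupAndLimit` are hypothetical stand-ins so it runs without Spring AI. A real version would implement `DocumentPostProcessor.process(Query, List<Document>)` and compare `Document.getText()` instead.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical stand-ins mirroring the shape of Document and DocumentPostProcessor.
record Doc(String id, String text) {}

interface PostProcessor {
    List<Doc> process(String query, List<Doc> documents);
}

// Keeps the first occurrence of each distinct text, then stops after maxDocs documents.
final class DedupAndLimit implements PostProcessor {
    private final int maxDocs;

    DedupAndLimit(int maxDocs) {
        this.maxDocs = maxDocs;
    }

    @Override
    public List<Doc> process(String query, List<Doc> documents) {
        Set<String> seen = new HashSet<>();
        List<Doc> result = new ArrayList<>();
        for (Doc doc : documents) {
            if (result.size() == maxDocs) {
                break;                      // cap reached
            }
            if (seen.add(doc.text())) {     // add() returns false for a text already seen
                result.add(doc);
            }
        }
        return result;
    }
}
```

Because the retriever returns documents ordered by similarity, keeping the first occurrence keeps the best-ranked copy of each text.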
RagModuleController
package com.spring.ai.tutorial.rag.controller;
import com.spring.ai.tutorial.rag.service.DocumentSelectFirst;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@RestController
@RequestMapping("/rag/module")
public class RagModuleController {
private static final Logger logger = LoggerFactory.getLogger(RagModuleController.class);
private final SimpleVectorStore simpleVectorStore;
private final ChatClient.Builder chatClientBuilder;
public RagModuleController(EmbeddingModel embeddingModel, ChatClient.Builder builder) {
this.simpleVectorStore = SimpleVectorStore
.builder(embeddingModel).build();
this.chatClientBuilder = builder;
}
@GetMapping("/add")
public void add() {
logger.info("start add data");
HashMap<String, Object> map = new HashMap<>();
map.put("year", 2025);
map.put("name", "yingzi");
List<Document> documents = List.of(
new Document("你的姓名是影子,湖南邵阳人,25年硕士毕业于北京科技大学,曾先后在百度、理想、快手实习,曾发表过一篇自然语言处理的sci,现在是一名AI研发工程师"),
new Document("你的姓名是影子,专业领域包含的数学、前后端、大数据、自然语言处理", Map.of("year", 2024)),
new Document("你姓名是影子,爱好是发呆、思考、运动", map));
simpleVectorStore.add(documents);
}
@GetMapping("/chat-rag-advisor")
public String chatRagAdvisor(@RequestParam(value = "query", defaultValue = "你好,请告诉我影子这个人的身份信息") String query) {
logger.info("start chat with rag-advisor");
// 1. Pre-Retrieval
// 1.1 MultiQueryExpander
MultiQueryExpander multiQueryExpander = MultiQueryExpander.builder()
.chatClientBuilder(this.chatClientBuilder)
.build();
// 1.2 TranslationQueryTransformer
TranslationQueryTransformer translationQueryTransformer = TranslationQueryTransformer.builder()
.chatClientBuilder(this.chatClientBuilder)
.targetLanguage("English")
.build();
// 2. Retrieval
// 2.1 VectorStoreDocumentRetriever
VectorStoreDocumentRetriever vectorStoreDocumentRetriever = VectorStoreDocumentRetriever.builder()
.vectorStore(simpleVectorStore)
.build();
// 2.2 ConcatenationDocumentJoiner
ConcatenationDocumentJoiner concatenationDocumentJoiner = new ConcatenationDocumentJoiner();
// 3. Post-Retrieval
// 3.1 DocumentSelectFirst
DocumentSelectFirst documentSelectFirst = new DocumentSelectFirst();
// 4. Generation
// 4.1 ContextualQueryAugmenter
ContextualQueryAugmenter contextualQueryAugmenter = ContextualQueryAugmenter.builder()
.allowEmptyContext(true)
.build();
RetrievalAugmentationAdvisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
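// Expand the query: by default 3 LLM-generated variants plus the original, 4 queries in total
.queryExpander(multiQueryExpander)
// Translate the query into English
.queryTransformers(translationQueryTransformer)
// Retrieve documents from the vector store
.documentRetriever(vectorStoreDocumentRetriever)
// Join the documents retrieved for all queries
.documentJoiner(concatenationDocumentJoiner)
// Post-process the retrieved documents: keep only the first one
.documentPostProcessors(documentSelectFirst)
// Augment the user query with the retrieved context
.queryAugmenter(contextualQueryAugmenter)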
.build();
return this.chatClientBuilder.build().prompt(query)
.advisors(retrievalAugmentationAdvisor)
.call().content();
}
}
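This example uses one component from every RAG stage:
Pre-Retrieval
- Expand the query: MultiQueryExpander
- Translate the query into English: TranslationQueryTransformer
Retrieval
- Retrieve documents from the vector store: VectorStoreDocumentRetriever
- Join the retrieved documents: ConcatenationDocumentJoiner
Post-Retrieval
- Select the first document: DocumentSelectFirst
Generation
- Augment the query with the retrieved context: ContextualQueryAugmenter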
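Result
First, the incoming originalQuery text is "你好,请告诉我影子这个人的身份信息" ("Hello, please tell me about the person Yingzi").
TranslationQueryTransformer translates it into English.
MultiQueryExpander then adds 3 variants by default and keeps the original, giving 4 queries.
Documents are retrieved from the vector store for each query.
The retrieved documents are joined.
DocumentSelectFirst keeps only the first document.
ContextualQueryAugmenter adds that document to the user message as context.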
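RAG Modular Source Code
Spring AI implements a modular RAG architecture inspired by the paper "Modular RAG: Transforming RAG Systems into LEGO-like Reconfigurable Frameworks". This part walks through the source code of those modules.
RetrievalAugmentationAdvisor
The RAG advisor. It uses the modular RAG components (Query, Pre-Retrieval, Retrieval, Post-Retrieval, Generation) to enrich the user's text with additional information.
Its core methods are before and after.
before:
- Create the original query (originalQuery): build a Query object from the user text, the prompt's messages (as history) and the request context; it is the basis for every later step
- Transform the query (transformedQuery): pass the query through each QueryTransformer in queryTransformers, in order. Each transformer may modify the query (normalize it, rewrite it, and so on); the final result is transformedQuery
- Expand the query (expandedQueries): if a queryExpander is configured, it expands transformedQuery into one or more queries (for example, synonym or paraphrase variants); otherwise only transformedQuery is used
- Retrieve documents (documentsForQuery): for each expanded query, getDocumentsForQuery runs asynchronously on taskExecutor and uses documentRetriever to fetch the relevant documents. The results are collected into a Map<Query, List<List<Document>>>
- Join documents (documents): documentJoiner merges the documents from all queries into a single list for further processing
- Post-process documents: each DocumentPostProcessor in documentPostProcessors further processes the merged list (deduplication, re-ranking, summarization, etc.) and receives the original query. The resulting list is stored in the context
- Augment the query: queryAugmenter combines the original query with the retrieved documents into an augmented query that carries the document context (for example, by adding the document contents to the user question)
- Update the request: the user message of the ChatClientRequest is replaced with the augmented query text, the document context is written into the request context, and the new request is returned for the rest of the chain
after:
- Adds the documents retrieved during RAG to the ChatResponse metadata under the key "rag_document_context"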
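To make the order of the before() steps concrete, here is a heavily simplified plain-Java model of the pipeline. It is a sketch under stated assumptions, not the advisor itself: queries and documents are plain Strings, every module is a hypothetical `java.util.function` stand-in, retrieval runs sequentially, and the join step simply concatenates and removes exact duplicates. `RagPipelineSketch` is an illustrative name.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.UnaryOperator;

// Hypothetical, simplified model of RetrievalAugmentationAdvisor.before():
// queries and documents are plain Strings instead of Query and Document.
final class RagPipelineSketch {

    static String before(String userText,
                         List<UnaryOperator<String>> transformers,
                         Function<String, List<String>> expander,   // may be null: no expansion
                         Function<String, List<String>> retriever,
                         List<BiFunction<String, List<String>, List<String>>> postProcessors,
                         BiFunction<String, List<String>, String> augmenter) {
        // 1. Transform: apply every transformer in order, each seeing the previous result.
        String transformed = userText;
        for (UnaryOperator<String> transformer : transformers) {
            transformed = transformer.apply(transformed);
        }
        // 2. Expand: one query may become several.
        List<String> expanded = expander != null ? expander.apply(transformed) : List.of(transformed);
        // 3. Retrieve for each query (sequentially here; the advisor runs these in parallel).
        List<String> documents = new ArrayList<>();
        for (String query : expanded) {
            documents.addAll(retriever.apply(query));
        }
        // 4. Join: concatenate and drop exact duplicates, keeping the first occurrence.
        documents = new ArrayList<>(new LinkedHashSet<>(documents));
        // 5. Post-process: like the advisor, pass the ORIGINAL user text, not the transformed one.
        for (BiFunction<String, List<String>, List<String>> processor : postProcessors) {
            documents = processor.apply(userText, documents);
        }
        // 6. Augment: combine the ORIGINAL user text with the surviving documents.
        return augmenter.apply(userText, documents);
    }
}
```

One design point worth noticing, and it matches the source listing: post-processors and the augmenter receive the original query, so translating or rewriting the query only changes what is retrieved, not the question the model finally answers.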
package org.springframework.ai.rag.advisor;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
import org.springframework.ai.rag.retrieval.join.DocumentJoiner;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.core.task.TaskExecutor;
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
import org.springframework.lang.Nullable;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.Assert;
import reactor.core.scheduler.Scheduler;
public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
public static final String DOCUMENT_CONTEXT = "rag_document_context";
private final List<QueryTransformer> queryTransformers;
@Nullable
private final QueryExpander queryExpander;
private final DocumentRetriever documentRetriever;
private final DocumentJoiner documentJoiner;
private final List<DocumentPostProcessor> documentPostProcessors;
private final QueryAugmenter queryAugmenter;
private final TaskExecutor taskExecutor;
private final Scheduler scheduler;
private final int order;
private RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers, @Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever, @Nullable DocumentJoiner documentJoiner, @Nullable List<DocumentPostProcessor> documentPostProcessors, @Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) {
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
this.queryTransformers = queryTransformers != null ? queryTransformers : List.of();
this.queryExpander = queryExpander;
this.documentRetriever = documentRetriever;
this.documentJoiner = (DocumentJoiner)(documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner());
this.documentPostProcessors = documentPostProcessors != null ? documentPostProcessors : List.of();
this.queryAugmenter = (QueryAugmenter)(queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build());
this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
this.order = order != null ? order : 0;
}
public static Builder builder() {
return new Builder();
}
public ChatClientRequest before(ChatClientRequest chatClientRequest, @Nullable AdvisorChain advisorChain) {
Map<String, Object> context = new HashMap(chatClientRequest.context());
Query originalQuery = Query.builder().text(chatClientRequest.prompt().getUserMessage().getText()).history(chatClientRequest.prompt().getInstructions()).context(context).build();
Query transformedQuery = originalQuery;
for(QueryTransformer queryTransformer : this.queryTransformers) {
transformedQuery = queryTransformer.apply(transformedQuery);
}
List<Query> expandedQueries = this.queryExpander != null ? this.queryExpander.expand(transformedQuery) : List.of(transformedQuery);
// Retrieve documents for every expanded query in parallel on taskExecutor, then wait for all results
Map<Query, List<List<Document>>> documentsForQuery = expandedQueries.stream()
.map(query -> CompletableFuture.supplyAsync(() -> this.getDocumentsForQuery(query), this.taskExecutor))
.toList()
.stream()
.map(CompletableFuture::join)
.collect(Collectors.toMap(Map.Entry::getKey, entry -> List.of(entry.getValue())));
List<Document> documents = this.documentJoiner.join(documentsForQuery);
for(DocumentPostProcessor documentPostProcessor : this.documentPostProcessors) {
documents = documentPostProcessor.process(originalQuery, documents);
}
context.put("rag_document_context", documents);
Query augmentedQuery = this.queryAugmenter.augment(originalQuery, documents);
return chatClientRequest.mutate().prompt(chatClientRequest.prompt().augmentUserMessage(augmentedQuery.text())).context(context).build();
}
private Map.Entry<Query, List<Document>> getDocumentsForQuery(Query query) {
List<Document> documents = this.documentRetriever.retrieve(query);
return Map.entry(query, documents);
}
public ChatClientResponse after(ChatClientResponse chatClientResponse, @Nullable AdvisorChain advisorChain) {
ChatResponse.Builder chatResponseBuilder;
if (chatClientResponse.chatResponse() == null) {
chatResponseBuilder = ChatResponse.builder();
} else {
chatResponseBuilder = ChatResponse.builder().from(chatClientResponse.chatResponse());
}
chatResponseBuilder.metadata("rag_document_context", chatClientResponse.context().get("rag_document_context"));
return ChatClientResponse.builder().chatResponse(chatResponseBuilder.build()).context(chatClientResponse.context()).build();
}
public Scheduler getScheduler() {
return this.scheduler;
}
public int getOrder() {
return this.order;
}
private static TaskExecutor buildDefaultTaskExecutor() {
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
taskExecutor.setThreadNamePrefix("ai-advisor-");
taskExecutor.setCorePoolSize(4);
taskExecutor.setMaxPoolSize(16);
taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
taskExecutor.initialize();
return taskExecutor;
}
public static final class Builder {
private List<QueryTransformer> queryTransformers;
private QueryExpander queryExpander;
private DocumentRetriever documentRetriever;
private DocumentJoiner documentJoiner;
private List<DocumentPostProcessor> documentPostProcessors;
private QueryAugmenter queryAugmenter;
private TaskExecutor taskExecutor;
private Scheduler scheduler;
private Integer order;
private Builder() {
}
public Builder queryTransformers(List<QueryTransformer> queryTransformers) {
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
this.queryTransformers = queryTransformers;
return this;
}
public Builder queryTransformers(QueryTransformer... queryTransformers) {
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
this.queryTransformers = Arrays.asList(queryTransformers);
return this;
}
public Builder queryExpander(QueryExpander queryExpander) {
this.queryExpander = queryExpander;
return this;
}
public Builder documentRetriever(DocumentRetriever documentRetriever) {
this.documentRetriever = documentRetriever;
return this;
}
public Builder documentJoiner(DocumentJoiner documentJoiner) {
this.documentJoiner = documentJoiner;
return this;
}
public Builder documentPostProcessors(List<DocumentPostProcessor> documentPostProcessors) {
Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements");
this.documentPostProcessors = documentPostProcessors;
return this;
}
public Builder documentPostProcessors(DocumentPostProcessor... documentPostProcessors) {
Assert.notNull(documentPostProcessors, "documentPostProcessors cannot be null");
Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements");
this.documentPostProcessors = Arrays.asList(documentPostProcessors);
return this;
}
public Builder queryAugmenter(QueryAugmenter queryAugmenter) {
this.queryAugmenter = queryAugmenter;
return this;
}
public Builder taskExecutor(TaskExecutor taskExecutor) {
this.taskExecutor = taskExecutor;
return this;
}
public Builder scheduler(Scheduler scheduler) {
this.scheduler = scheduler;
return this;
}
public Builder order(Integer order) {
this.order = order;
return this;
}
public RetrievalAugmentationAdvisor build() {
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.documentRetriever, this.documentJoiner, this.documentPostProcessors, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order);
}
}
}
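The before() listing runs one retrieval per expanded query on taskExecutor (by default a ThreadPoolTaskExecutor with 4 core and 16 max threads) and collects the results with the two-argument `Collectors.toMap`. In the JDK that collector throws `IllegalStateException` on duplicate keys, and Query is a record with value-based equality, so two identical expanded queries would appear to make the collection fail; treat this as a reading of the decompiled listing rather than a confirmed bug. The sketch below shows the same fan-out/fan-in pattern in plain Java, with a merge function that keeps the first result and a LinkedHashMap that preserves query order. `ParallelRetrieval` and `retrieveAll` are hypothetical names.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical helper mirroring the fan-out/fan-in step of before():
// one asynchronous retrieval per query, then wait for all of them.
final class ParallelRetrieval {

    static Map<String, List<String>> retrieveAll(List<String> queries,
                                                 Function<String, List<String>> retriever,
                                                 Executor executor) {
        // Submit every retrieval first (toList() forces this), so the calls overlap in time.
        List<CompletableFuture<Map.Entry<String, List<String>>>> futures = queries.stream()
                .map(q -> CompletableFuture.supplyAsync(() -> Map.entry(q, retriever.apply(q)), executor))
                .toList();
        // Then block on each future. The merge function keeps the first result for a
        // duplicate query instead of throwing, and LinkedHashMap keeps the query order.
        return futures.stream()
                .map(CompletableFuture::join)
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,
                        (first, second) -> first, LinkedHashMap::new));
    }
}
```

Collecting the futures into a list before calling join() matters: chaining join() directly onto the first stream would start and wait for each retrieval one at a time.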
Query
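The class that represents a query in the RAG pipeline.
- String text: the query text, i.e. the core question the user typed
- List<Message> history: the conversation history relevant to this query
- Map<String, Object> context: key-value context for the query, holding any extra data related to it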
package org.springframework.ai.rag;
import java.util.List;
import java.util.Map;
import org.springframework.ai.chat.messages.Message;
import org.springframework.util.Assert;
public record Query(String text, List<Message> history, Map<String, Object> context) {
public Query {
Assert.hasText(text, "text cannot be null or empty");
Assert.notNull(history, "history cannot be null");
Assert.noNullElements(history, "history elements cannot be null");
Assert.notNull(context, "context cannot be null");
Assert.noNullElements(context.keySet(), "context keys cannot be null");
}
public Query(String text) {
this(text, List.of(), Map.of());
}
public Builder mutate() {
return (new Builder()).text(this.text).history(this.history).context(this.context);
}
public static Builder builder() {
return new Builder();
}
public static final class Builder {
private String text;
private List<Message> history = List.of();
private Map<String, Object> context = Map.of();
private Builder() {
}
public Builder text(String text) {
this.text = text;
return this;
}
public Builder history(List<Message> history) {
this.history = history;
return this;
}
public Builder history(Message... history) {
this.history = List.of(history);
return this;
}
public Builder context(Map<String, Object> context) {
this.context = context;
return this;
}
public Query build() {
return new Query(this.text, this.history, this.context);
}
}
}
Pre-Retrieval
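QueryExpander (query expansion interface)
Purpose:
- Handle poorly formed queries: offering alternative phrasings helps improve query quality
- Break down complex questions: split a complex query into simpler sub-queries that are easier to process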
package org.springframework.ai.rag.preretrieval.query.expansion;
import java.util.List;
import java.util.function.Function;
import org.springframework.ai.rag.Query;
public interface QueryExpander extends Function<Query, List<Query>> {
List<Query> expand(Query query);
default List<Query> apply(Query query) {
return this.expand(query);
}
}
MultiQueryExpander
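Expands a query by using an LLM to turn a single query into several semantically diverse variants. The variants cover the original topic from different angles or aspects, which increases the chance of retrieving relevant results.
Fields:
- ChatClient chatClient: interacts with the LLM to generate the query variants
- PromptTemplate promptTemplate: the prompt template used to generate the variants. The default template asks for the given number of variants, each covering a different perspective or aspect
- boolean includeOriginal: whether the original query is included in the result; defaults to true
- int numberOfQueries: how many query variants to generate; defaults to 3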
package org.springframework.ai.rag.preretrieval.query.expansion;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
public final class MultiQueryExpander implements QueryExpander {
private static final Logger logger = LoggerFactory.getLogger(MultiQueryExpander.class);
private static final PromptTemplate DEFAULTPROMPTTEMPLATE = new PromptTemplate("You are an expert at information retrieval and search optimization.\nYour task is to generate {number} different versions of the given query.\n\nEach variant must cover different perspectives or aspects of the topic,\nwhile maintaining the core intent of the original query. The goal is to\nexpand the search space and improve the chances of finding relevant information.\n\nDo not explain your choices or add any other text.\nProvide the query variants separated by newlines.\n\nOriginal query: {query}\n\nQuery variants:\n");
private static final Boolean DEFAULTINCLUDEORIGINAL = true;
private static final Integer DEFAULTNUMBEROFQUERIES = 3;
private final ChatClient chatClient;
private final PromptTemplate promptTemplate;
private final boolean includeOriginal;
private final int numberOfQueries;
public MultiQueryExpander(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate, @Nullable Boolean includeOriginal, @Nullable Integer numberOfQueries) {
Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");
this.chatClient = chatClientBuilder.build();
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULTPROMPTTEMPLATE;
this.includeOriginal = includeOriginal != null ? includeOriginal : DEFAULTINCLUDEORIGINAL;
this.numberOfQueries = numberOfQueries != null ? numberOfQueries : DEFAULTNUMBEROFQUERIES;
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, new String[]{"number", "query"});
}
public List<Query> expand(Query query) {
Assert.notNull(query, "query cannot be null");
logger.debug("Generating {} query variants", this.numberOfQueries);
String response = this.chatClient.prompt().user((user) -> user.text(this.promptTemplate.getTemplate()).param("number", this.numberOfQueries).param("query", query.text())).call().content();
if (response == null) {
logger.warn("Query expansion result is null. Returning the input query unchanged.");
return List.of(query);
} else {
List<String> queryVariants = Arrays.asList(response.split("\n"));
if (!CollectionUtils.isEmpty(queryVariants) && this.numberOfQueries == queryVariants.size()) {
List<Query> queries = (List)queryVariants.stream().filter(StringUtils::hasText).map((queryText) -> query.mutate().text(queryText).build()).collect(Collectors.toList());
if (this.includeOriginal) {
logger.debug("Including the original query in the result");
queries.add(0, query);
}
return queries;
} else {
logger.warn("Query expansion result does not contain the requested {} variants. Returning the input query unchanged.", this.numberOfQueries);
return List.of(query);
}
}
}
public static Builder builder() {
return new Builder();
}
public static final class Builder {
private ChatClient.Builder chatClientBuilder;
private PromptTemplate promptTemplate;
private Boolean includeOriginal;
private Integer numberOfQueries;
private Builder() {
}
public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
this.chatClientBuilder = chatClientBuilder;
return this;
}
public Builder promptTemplate(PromptTemplate promptTemplate) {
this.promptTemplate = promptTemplate;
return this;
}
public Builder includeOriginal(Boolean includeOriginal) {
this.includeOriginal = includeOriginal;
return this;
}
public Builder numberOfQueries(Integer numberOfQueries) {
this.numberOfQueries = numberOfQueries;
return this;
}
public MultiQueryExpander build() {
return new MultiQueryExpander(this.chatClientBuilder, this.promptTemplate, this.includeOriginal, this.numberOfQueries);
}
}
}
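The parsing rules in expand() are easy to miss, so here is a plain-Java re-creation of just that part, assuming `response` is the raw model output; `ExpansionParsing.parse` is a hypothetical helper, not part of Spring AI. As in the listing, the line count is compared with numberOfQueries before blank lines are filtered out.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical re-creation of the response handling in MultiQueryExpander.expand().
final class ExpansionParsing {

    static List<String> parse(String original, String response, int numberOfQueries, boolean includeOriginal) {
        if (response == null) {
            return List.of(original);                 // no answer: fall back to the input query
        }
        List<String> variants = Arrays.asList(response.split("\n"));
        // The size check happens BEFORE blank lines are removed.
        if (variants.size() != numberOfQueries) {
            return List.of(original);
        }
        List<String> queries = new ArrayList<>();
        if (includeOriginal) {
            queries.add(original);                    // the original query goes first
        }
        for (String variant : variants) {
            if (!variant.isBlank()) {
                queries.add(variant);
            }
        }
        return queries;
    }
}
```

With the defaults (numberOfQueries = 3, includeOriginal = true) you therefore get 4 queries only when the model returns exactly three lines. A trailing newline is harmless, because `String.split` drops trailing empty strings, but a blank line between variants makes the expander silently fall back to the original query alone.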
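QueryTransformer (query transformation interface)
Purpose: transform a query when:
- the query is incomplete or poorly structured
- terms in the query are ambiguous
- the query uses complex or hard-to-understand vocabulary
- the query is written in an unsupported language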
package org.springframework.ai.rag.preretrieval.query.transformation;
import java.util.function.Function;
import org.springframework.ai.rag.Query;
public interface QueryTransformer extends Function<Query, Query> {
Query transform(Query query);
default Query apply(Query query) {
return this.transform(query);
}
}
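Because QueryTransformer extends `Function<Query, Query>`, the loop in before() that applies queryTransformers one after another is equivalent to folding them with `Function.andThen`. A plain-Java sketch, with String standing in for Query and a hypothetical `TransformerChain` helper:

```java
import java.util.List;
import java.util.function.Function;

// Hypothetical demo: a list of transformers folded into a single function.
// Strings stand in for Query objects.
final class TransformerChain {

    static Function<String, String> chain(List<Function<String, String>> transformers) {
        // identity() is the neutral element; andThen preserves list order (first runs first).
        return transformers.stream()
                .reduce(Function.identity(), Function::andThen);
    }
}
```

Starting from `Function.identity()` makes an empty list a no-op, which matches the advisor's behavior when no queryTransformers are configured.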
CompressionQueryTransformer
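Compresses the conversation history and a follow-up query.
Purpose: merge the conversation context and the follow-up query into one standalone query that captures the core of the conversation.
Use when: the conversation history is long and the follow-up query depends on that context.
Fields:
- ChatClient chatClient: interacts with the LLM to generate the compressed query
- PromptTemplate promptTemplate: custom prompt text used to generate the compressed query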
package org.springframework.ai.rag.preretrieval.query.transformation;
import java.util.List;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
public class CompressionQueryTransformer implements QueryTransformer {
private static final Logger logger = LoggerFactory.getLogger(CompressionQueryTransformer.class);
private static final PromptTemplate DEFAULTPROMPTTEMPLATE = new PromptTemplate("Given the following conversation history and a follow-up query, your task is to synthesize\na concise, standalone query that incorporates the context from the history.\nEnsure the standalone query is clear, specific, and maintains the user's intent.\n\nConversation history:\n{history}\n\nFollow-up query:\n{query}\n\nStandalone query:\n");
private final ChatClient chatClient;
private final PromptTemplate promptTemplate;
public CompressionQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate) {
Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");
this.chatClient = chatClientBuilder.build();
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULTPROMPTTEMPLATE;
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, new String[]{"history", "query"});
}
public Query transform(Query query) {
Assert.notNull(query, "query cannot be null");
logger.debug("Compressing conversation history and follow-up query into a standalone query");
String compressedQueryText = this.chatClient.prompt().user((user) -> user.text(this.promptTemplate.getTemplate()).param("history", this.formatConversationHistory(query.history())).param("query", query.text())).call().content();
if (!StringUtils.hasText(compressedQueryText)) {
logger.warn("Query compression result is null/empty. Returning the input query unchanged.");
return query;
} else {
return query.mutate().text(compressedQueryText).build();
}
}
private String formatConversationHistory(List<Message> history) {
return history.isEmpty() ? "" : (String)history.stream().filter((message) -> message.getMessageType().equals(MessageType.USER) || message.getMessageType().equals(MessageType.ASSISTANT)).map((message) -> "%s: %s".formatted(message.getMessageType(), message.getText())).collect(Collectors.joining("\n"));
}
public static Builder builder() {
return new Builder();
}
public static final class Builder {
private ChatClient.Builder chatClientBuilder;
@Nullable
private PromptTemplate promptTemplate;
private Builder() {
}
public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
this.chatClientBuilder = chatClientBuilder;
return this;
}
public Builder promptTemplate(PromptTemplate promptTemplate) {
this.promptTemplate = promptTemplate;
return this;
}
public CompressionQueryTransformer build() {
return new CompressionQueryTransformer(this.chatClientBuilder, this.promptTemplate);
}
}
}
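formatConversationHistory keeps only user and assistant turns and renders one `ROLE: text` line per message. Below is a plain-Java sketch of that filtering, with hypothetical `Role`, `Msg` and `HistoryFormatter` types standing in for MessageType and Message. The exact label Spring AI prints depends on how MessageType renders with `%s`, which this sketch does not claim to reproduce; it simply uses the enum constant name.

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical stand-ins for Spring AI's MessageType and Message.
enum Role { SYSTEM, USER, ASSISTANT, TOOL }

record Msg(Role role, String text) {}

final class HistoryFormatter {

    // Keep only USER/ASSISTANT turns (system and tool messages are dropped), one line each.
    static String format(List<Msg> history) {
        return history.stream()
                .filter(m -> m.role() == Role.USER || m.role() == Role.ASSISTANT)
                .map(m -> "%s: %s".formatted(m.role(), m.text()))
                .collect(Collectors.joining("\n"));
    }
}
```

Dropping system and tool messages keeps the compression prompt focused on what the user actually asked and what the assistant answered.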
RewriteQueryTransformer
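Rewrites the user query.
Purpose: use an LLM to optimize the query so that searching the target system returns better results.
Use when: the user query is verbose, ambiguous, or contains irrelevant information.
Fields:
- PromptTemplate promptTemplate: custom rewrite template
- ChatClient chatClient: interacts with the LLM to rewrite the query
- String targetSearchSystem: name of the target system, filled into the {target} placeholder of the prompt template; defaults to "vector store"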
package org.springframework.ai.rag.preretrieval.query.transformation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
public class RewriteQueryTransformer implements QueryTransformer {
private static final Logger logger = LoggerFactory.getLogger(RewriteQueryTransformer.class);
private static final PromptTemplate DEFAULTPROMPTTEMPLATE = new PromptTemplate("Given a user query, rewrite it to provide better results when querying a {target}.\nRemove any irrelevant information, and ensure the query is concise and specific.\n\nOriginal query:\n{query}\n\nRewritten query:\n");
private static final String DEFAULTTARGET = "vector store";
private final ChatClient chatClient;
private final PromptTemplate promptTemplate;
private final String targetSearchSystem;
public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate, @Nullable String targetSearchSystem) {
Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");
this.chatClient = chatClientBuilder.build();
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULTPROMPTTEMPLATE;
this.targetSearchSystem = targetSearchSystem != null ? targetSearchSystem : "vector store";
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, new String[]{"target", "query"});
}
public Query transform(Query query) {
Assert.notNull(query, "query cannot be null");
logger.debug("Rewriting query to optimize for querying a {}.", this.targetSearchSystem);
String rewrittenQueryText = this.chatClient.prompt().user((user) -> user.text(this.promptTemplate.getTemplate()).param("target", this.targetSearchSystem).param("query", query.text())).call().content();
if (!StringUtils.hasText(rewrittenQueryText)) {
logger.warn("Query rewrite result is null/empty. Returning the input query unchanged.");
return query;
} else {
return query.mutate().text(rewrittenQueryText).build();
}
}
public static Builder builder() {
return new Builder();
}
public static final class Builder {
private ChatClient.Builder chatClientBuilder;
@Nullable
private PromptTemplate promptTemplate;
@Nullable
private String targetSearchSystem;
private Builder() {
}
public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
this.chatClientBuilder = chatClientBuilder;
return this;
}
public Builder promptTemplate(PromptTemplate promptTemplate) {
this.promptTemplate = promptTemplate;
return this;
}
public Builder targetSearchSystem(String targetSearchSystem) {
this.targetSearchSystem = targetSearchSystem;
return this;
}
public RewriteQueryTransformer build() {
return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem);
}
}
}
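The constructor's placeholder check and the `param("target", ...)` / `param("query", ...)` substitution in `transform()` can be illustrated without Spring AI. The following is a minimal plain-Java sketch; `PromptSketch`, `requirePlaceholders` and `render` are hypothetical names, and it assumes simple `{name}` text replacement, whereas the real `PromptTemplate` renders through a template engine.

```java
import java.util.List;
import java.util.Map;

// Hypothetical helper mirroring PromptAssert.templateHasRequiredPlaceholders + param(...) rendering.
// Assumption: placeholders are plain "{name}" tokens replaced verbatim (not full template-engine semantics).
final class PromptSketch {

    // Same text as DEFAULT_PROMPT_TEMPLATE above.
    static final String REWRITE_TEMPLATE = """
            Given a user query, rewrite it to provide better results when querying a {target}.
            Remove any irrelevant information, and ensure the query is concise and specific.

            Original query:
            {query}

            Rewritten query:
            """;

    // Fails fast, like the constructor does, when a required placeholder is missing.
    static void requirePlaceholders(String template, List<String> names) {
        for (String name : names) {
            if (!template.contains("{" + name + "}")) {
                throw new IllegalArgumentException("Missing placeholder: " + name);
            }
        }
    }

    // Substitutes every {name} with its value.
    static String render(String template, Map<String, String> params) {
        String result = template;
        for (Map.Entry<String, String> e : params.entrySet()) {
            result = result.replace("{" + e.getKey() + "}", e.getValue());
        }
        return result;
    }
}
```

In Spring AI itself the same two values are supplied through the Builder shown above, e.g. `RewriteQueryTransformer.builder().chatClientBuilder(builder).targetSearchSystem("vector store").build()`; leaving `targetSearchSystem` unset falls back to the `"vector store"` default.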
TranslationQueryTransformer
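A utility class that translates the user query into a target language
Purpose: uses an LLM to translate the user query into the target language
When to use: when the embedding model supports only a specific language and users ask questions in a different one
Fields
- ChatClient chatClient: interacts with the LLM that performs the translation
- PromptTemplate promptTemplate: custom prompt template for the translation request
- String targetLanguage: the target language (required; must not be null or empty)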
package org.springframework.ai.rag.preretrieval.query.transformation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
public final class TranslationQueryTransformer implements QueryTransformer {
private static final Logger logger = LoggerFactory.getLogger(TranslationQueryTransformer.class);
private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("Given a user query, translate it to {targetLanguage}.\nIf the query is already in {targetLanguage}, return it unchanged.\nIf you don't know the language of the query, return it unchanged.\nDo not add explanations nor any other text.\n\nOriginal query: {query}\n\nTranslated query:\n");
private final ChatClient chatClient;
private final PromptTemplate promptTemplate;
private final String targetLanguage;
public TranslationQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate, String targetLanguage) {
Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");
Assert.hasText(targetLanguage, "targetLanguage cannot be null or empty");
this.chatClient = chatClientBuilder.build();
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
this.targetLanguage = targetLanguage;
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, new String[]{"targetLanguage", "query"});
}
public Query transform(Query query) {
Assert.notNull(query, "query cannot be null");
logger.debug("Translating query to target language: {}", this.targetLanguage);
String translatedQueryText = this.chatClient.prompt().user((user) -> user.text(this.promptTemplate.getTemplate()).param("targetLanguage", this.targetLanguage).param("query", query.text())).call().content();
if (!StringUtils.hasText(translatedQueryText)) {
logger.warn("Query translation result is null/empty. Returning the input query unchanged.");
return query;
} else {
return query.mutate().text(translatedQueryText).build();
}
}
public static Builder builder() {
return new Builder();
}
public static final class Builder {
private ChatClient.Builder chatClientBuilder;
@Nullable
private PromptTemplate promptTemplate;
private String targetLanguage;
private Builder() {
}
public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
this.chatClientBuilder = chatClientBuilder;
return this;
}
public Builder promptTemplate(PromptTemplate promptTemplate) {
this.promptTemplate = promptTemplate;
return this;
}
public Builder targetLanguage(String targetLanguage) {
this.targetLanguage = targetLanguage;
return this;
}
public TranslationQueryTransformer build() {
return new TranslationQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetLanguage);
}
}
}
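The key behavior of `transform()` is the fallback: if the model returns null or blank text, the original query is passed through unchanged. Below is a minimal sketch of that flow, assuming the "model" is any `Function<String, String>` (in Spring AI it is the `ChatClient` call). `SketchQuery` and `TranslationSketch` are hypothetical, and the prompt is an abbreviated version of `DEFAULT_PROMPT_TEMPLATE`.

```java
import java.util.Objects;
import java.util.function.Function;

// Hypothetical stand-in for org.springframework.ai.rag.Query (text only; the real one also carries history and context).
record SketchQuery(String text) { }

// Minimal sketch of TranslationQueryTransformer.transform(): build the prompt, call the model,
// and fall back to the original query when the model returns null or blank text.
// Assumption: the model is any Function<String, String>; in Spring AI it is a ChatClient call.
final class TranslationSketch {
    private final Function<String, String> model;
    private final String targetLanguage;

    TranslationSketch(Function<String, String> model, String targetLanguage) {
        this.model = Objects.requireNonNull(model, "model cannot be null");
        if (targetLanguage == null || targetLanguage.isBlank()) {
            throw new IllegalArgumentException("targetLanguage cannot be null or empty");
        }
        this.targetLanguage = targetLanguage;
    }

    SketchQuery transform(SketchQuery query) {
        Objects.requireNonNull(query, "query cannot be null");
        // Abbreviated version of the default prompt.
        String prompt = "Given a user query, translate it to " + targetLanguage + ".\n"
                + "Original query: " + query.text() + "\n\nTranslated query:\n";
        String translated = model.apply(prompt);
        // Same guard as StringUtils.hasText(...) in the original: null or blank keeps the input.
        if (translated == null || translated.isBlank()) {
            return query;
        }
        return new SketchQuery(translated);
    }
}
```

The real class uses `query.mutate().text(...)` instead of creating a new object, so the query's history and context survive the translation; the single-field record here drops that detail for brevity.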
Retrieval
DocumentRetriever (common interface for document retrieval)
package org.springframework.ai.rag.retrieval.search;
import java.util.List;
import java.util.function.Function;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
public interface DocumentRetriever extends Function<Query, List<Document>> {
List<Document> retrieve(Query query);
default List<Document> apply(Query query) {
return this.retrieve(query);
}
}
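Because `DocumentRetriever` extends `Function<Query, List<Document>>` and routes `apply()` to `retrieve()`, any retriever can be written as a lambda and composed with standard `Function` methods. The sketch below mirrors that shape with hypothetical stand-in types (`RQuery`, `RDocument`, `SketchRetriever`, `RetrieverDemo`); the keyword retriever is an assumption for illustration, since real retrievers query a vector store.

```java
import java.util.List;
import java.util.function.Function;

// Hypothetical minimal types standing in for Spring AI's Query and Document.
record RQuery(String text) { }
record RDocument(String id, String text, Double score) { }

// Mirrors the shape of DocumentRetriever: retrieve(...) is the only abstract method and apply(...) delegates to it,
// so any retriever can be used wherever a Function<Query, List<Document>> is expected.
interface SketchRetriever extends Function<RQuery, List<RDocument>> {
    List<RDocument> retrieve(RQuery query);

    @Override
    default List<RDocument> apply(RQuery query) {
        return retrieve(query);
    }
}

final class RetrieverDemo {
    // Toy case-insensitive keyword retriever over a fixed corpus (illustration only).
    static SketchRetriever keywordRetriever(List<RDocument> corpus) {
        return query -> corpus.stream()
                .filter(doc -> doc.text().toLowerCase().contains(query.text().toLowerCase()))
                .toList();
    }

    // Because it is a Function, standard composition such as andThen(...) works out of the box.
    static Function<RQuery, Integer> countHits(SketchRetriever retriever) {
        return retriever.andThen(List::size);
    }
}
```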
VectorStoreDocumentRetriever
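Retrieves documents from a VectorStore that are semantically similar to the input query
Fields
- VectorStore vectorStore: the vector store instance that stores and retrieves the documents
- Double similarityThreshold: similarity threshold; documents scoring below it are filtered out (defaults to 0.0, i.e. accept all)
- Integer topK: maximum number of documents to return (defaults to 4)
- Supplier<Filter.Expression> filterExpression: supplies a metadata filter at runtime, so the filter can be computed dynamically from the current context; a filter placed in the query context takes precedence over it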
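The effect of `similarityThreshold` and `topK` can be seen in a small in-memory search. This is a sketch under stated assumptions, not the `VectorStore` implementation: `InMemorySearchSketch` and `Hit` are hypothetical, embeddings are passed in directly instead of being produced by an `EmbeddingModel`, and cosine similarity is assumed as the score. The defaults (0.0 and 4) follow the constructor below.

```java
import java.util.Comparator;
import java.util.List;
import java.util.Map;

// Hypothetical in-memory stand-in for VectorStore.similaritySearch(SearchRequest): scores every stored
// embedding against the query embedding with cosine similarity, drops results below the threshold,
// and returns at most topK results, best first.
// Assumption: embeddings are given directly; a real store would embed the query text first.
final class InMemorySearchSketch {
    static final double DEFAULT_SIMILARITY_THRESHOLD = 0.0;
    static final int DEFAULT_TOP_K = 4;

    record Hit(String id, double score) { }

    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    static List<Hit> search(Map<String, double[]> store, double[] queryEmbedding,
                            Double similarityThreshold, Integer topK) {
        // Null parameters fall back to the same defaults as VectorStoreDocumentRetriever.
        double threshold = similarityThreshold != null ? similarityThreshold : DEFAULT_SIMILARITY_THRESHOLD;
        int k = topK != null ? topK : DEFAULT_TOP_K;
        return store.entrySet().stream()
                .map(e -> new Hit(e.getKey(), cosine(e.getValue(), queryEmbedding)))
                .filter(hit -> hit.score() >= threshold)
                .sorted(Comparator.comparingDouble(Hit::score).reversed())
                .limit(k)
                .toList();
    }
}
```

With the default threshold of 0.0 even an orthogonal document (score 0.0) is still returned, so in practice a positive threshold is what actually prunes weak matches.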
package org.springframework.ai.rag.retrieval.search;
import java.util.List;
import java.util.function.Supplier;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
public final class VectorStoreDocumentRetriever implements DocumentRetriever {
public static final String FILTER_EXPRESSION = "vector_store_filter_expression";
private final VectorStore vectorStore;
private final Double similarityThreshold;
private final Integer topK;
private final Supplier<Filter.Expression> filterExpression;
public VectorStoreDocumentRetriever(VectorStore vectorStore, @Nullable Double similarityThreshold, @Nullable Integer topK, @Nullable Supplier<Filter.Expression> filterExpression) {
Assert.notNull(vectorStore, "vectorStore cannot be null");
Assert.isTrue(similarityThreshold == null || similarityThreshold >= 0.0, "similarityThreshold must be equal to or greater than 0.0");
Assert.isTrue(topK == null || topK > 0, "topK must be greater than 0");
this.vectorStore = vectorStore;
this.similarityThreshold = similarityThreshold != null ? similarityThreshold : 0.0;
this.topK = topK != null ? topK : 4;
this.filterExpression = filterExpression != null ? filterExpression : () -> null;
}
public List<Document> retrieve(Query query) {
Assert.notNull(query, "query cannot be null");
Filter.Expression requestFilterExpression = this.computeRequestFilterExpression(query);
SearchRequest searchRequest = SearchRequest.builder().query(query.text()).filterExpression(requestFilterExpression).similarityThreshold(this.similarityThreshold).topK(this.topK).build();
return this.vectorStore.similaritySearch(searchRequest);
}
private Filter.Expression computeRequestFilterExpression(Query query) {
Object contextFilterExpression = query.context().get(FILTER_EXPRESSION);
if (contextFilterExpression != null) {
if (contextFilterExpression instanceof Filter.Expression) {
return (Filter.Expression)contextFilterExpression;
}
if (StringUtils.hasText(contextFilterExpression.toString())) {
return (new FilterExpressionTextParser()).parse(contextFilterExpression.toString());
}
}
return this.filterExpression.get();
}
public static Builder builder() {
return new Builder();
}
public static final class Builder {
private VectorStore vectorStore;
private Double similarityThreshold;
private Integer topK;
private Supplier<Filter.Expression> filterExpression;
private Builder() {
}
public Builder vectorStore(VectorStore vectorStore) {
this.vectorStore = vectorStore;
return this;
}
public Builder similarityThreshold(Double similarityThreshold) {
this.similarityThreshold = similarityThreshold;
return this;
}
public Builder topK(Integer topK) {
this.topK = topK;
return this;
}
public Builder filterExpression(Filter.Expression filterExpression) {
this.filterExpression = () -> filterExpression;
return this;
}
public Builder filterExpression(Supplier<Filter.Expression> filterExpression) {
this.filterExpression = filterExpression;
return this;
}
public VectorStoreDocumentRetriever build() {
return new VectorStoreDocumentRetriever(this.vectorStore, this.similarityThreshold, this.topK, this.filterExpression);
}
}
}
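The precedence rule in `computeRequestFilterExpression()` can be sketched on its own: a non-blank filter stored in the query context under the `FILTER_EXPRESSION` key (value `vector_store_filter_expression`) wins, otherwise the retriever-level `Supplier` is used. `FilterResolutionSketch` is hypothetical, and it assumes filters are plain strings; the real code also accepts a ready-made `Filter.Expression` and parses text with `FilterExpressionTextParser`, which the `trim()` call only stands in for.

```java
import java.util.Map;
import java.util.function.Supplier;

// Hypothetical sketch of computeRequestFilterExpression(): a filter found in the query context under
// FILTER_EXPRESSION wins; a blank or missing one falls back to the retriever-level Supplier.
// Assumption: filters are modeled as plain strings and "parsing" is represented by trim().
final class FilterResolutionSketch {
    static final String FILTER_EXPRESSION = "vector_store_filter_expression";

    static String resolve(Map<String, ?> queryContext, Supplier<String> defaultFilter) {
        Object contextFilter = queryContext.get(FILTER_EXPRESSION);
        if (contextFilter != null) {
            String text = contextFilter.toString();
            if (!text.isBlank()) {
                return text.trim(); // stands in for FilterExpressionTextParser.parse(text)
            }
        }
        return defaultFilter.get(); // may be null, meaning "no filter"
    }
}
```

This is what makes per-request filtering possible: the Spring AI reference documentation shows passing the value through the advisor context, e.g. `.advisors(a -> a.param(VectorStoreDocumentRetriever.FILTER_EXPRESSION, "type == 'Spring'"))`, while the builder-level `filterExpression(...)` acts as the default.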
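DocumentJoiner (common interface for joining documents)
Combines documents retrieved for multiple queries and from multiple data sources into a single document collection
Responsibilities (handled by its implementations):
- Merging: combines documents retrieved from different data sources into one collection
- Deduplication: handles duplicate documents during the merge
- Ranking: supports ranking the merged documents
When to use: when documents are retrieved for several queries or from several data sources and the results must be merged into one unified collection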
package org.springframework.ai.rag.retrieval.join;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
public interface DocumentJoiner extends Function<Map<Query, List<List<Document>>>, List<Document>> {
List<Document> join(Map<Query, List<List<Document>>> documentsForQuery);
default List<Document> apply(Map<Query, List<List<Document>>> documentsForQuery) {
return this.join(documentsForQuery);
}
}
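The nested input type `Map<Query, List<List<Document>>>` reflects a fan-out: each (possibly expanded) query is sent to every retriever, yielding one result list per data source, grouped under the query that produced it. Here is a minimal sketch of building that structure; `FanOutSketch` is hypothetical, and queries and documents are assumed to be plain strings instead of Spring AI's `Query`/`Document`.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical sketch of where DocumentJoiner's input comes from: every query is sent to every retriever,
// giving one List<Document> per data source, grouped under the query that produced it.
// Assumption: queries and documents are plain strings here.
final class FanOutSketch {
    static Map<String, List<List<String>>> fanOut(List<String> queries,
                                                 List<Function<String, List<String>>> retrievers) {
        // LinkedHashMap keeps query order stable for readability.
        Map<String, List<List<String>>> documentsForQuery = new LinkedHashMap<>();
        for (String query : queries) {
            List<List<String>> perSource = new ArrayList<>();
            for (Function<String, List<String>> retriever : retrievers) {
                perSource.add(retriever.apply(query));
            }
            documentsForQuery.put(query, perSource);
        }
        return documentsForQuery;
    }
}
```

A joiner then reduces this nested map to one flat list; a document such as `"docs:extra"` that comes back for both queries is exactly the kind of duplicate the deduplication step exists for.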
ConcatenationDocumentJoiner
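Concatenates the documents retrieved for multiple queries and from multiple data sources. As the code below shows, it flattens all result lists, removes duplicates by document ID (keeping the first occurrence), and sorts the result by score in descending order, treating a missing score as 0.0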
package org.springframework.ai.rag.retrieval.join;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.util.Assert;
public class ConcatenationDocumentJoiner implements DocumentJoiner {
private static final Logger logger = LoggerFactory.getLogger(ConcatenationDocumentJoiner.class);
public List<Document> join(Map<Query, List<List<Document>>> documentsForQuery) {
Assert.notNull(documentsForQuery, "documentsForQuery cannot be null");
Assert.noNullElements(documentsForQuery.keySet(), "documentsForQuery cannot contain null keys");
Assert.noNullElements(documentsForQuery.values(), "documentsForQuery cannot contain null values");
logger.debug("Joining documents by concatenation");
return new ArrayList<>(documentsForQuery.values().stream().flatMap(Collection::stream).flatMap(Collection::stream).collect(Collectors.toMap(Document::getId, Function.identity(), (existing, duplicate) -> existing)).values().stream().sorted(Comparator.comparingDouble((Document doc) -> doc.getScore() != null ? doc.getScore() : 0.0).reversed()).toList());
}
}
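To make the join semantics easy to test, here is a plain-Java restatement of `join()` using a hypothetical `JDoc` record in place of `Document`. One deliberate difference, for deterministic output: it collects into a `LinkedHashMap`, whereas the original `Collectors.toMap(...)` uses a `HashMap`, so the order in which duplicates are encountered may differ.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical stand-in for Spring AI's Document (id, text, nullable score).
record JDoc(String id, String text, Double score) { }

// Plain-Java restatement of ConcatenationDocumentJoiner.join(): flatten all lists, keep the first document
// seen for each id, then sort by score descending with a null score treated as 0.0.
// Difference (for determinism): LinkedHashMap instead of toMap's default HashMap.
final class ConcatenationSketch {
    static List<JDoc> join(Map<String, List<List<JDoc>>> documentsForQuery) {
        Map<String, JDoc> byId = documentsForQuery.values().stream()
                .flatMap(Collection::stream)   // List<List<JDoc>> -> List<JDoc> (one per data source)
                .flatMap(Collection::stream)   // List<JDoc> -> JDoc
                .collect(Collectors.toMap(JDoc::id, Function.identity(),
                        (existing, duplicate) -> existing, LinkedHashMap::new));
        return new ArrayList<>(byId.values().stream()
                .sorted(Comparator.comparingDouble((JDoc doc) -> doc.score() != null ? doc.score() : 0.0).reversed())
                .toList());
    }
}
```

Note the design consequence of `(existing, duplicate) -> existing`: the first copy of a document wins even if a later duplicate carries a higher score, so the final ranking can depend on which query or source returned the document first.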
Post-Retrieval
DocumentPostProcessor
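After retrieval, implement this interface to post-process the retrieved documents, e.g. to compress them, re-rank them, or keep only a subset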
package org.springframework.ai.rag.postretrieval.document;
import java.util.List;
import java.util.function.BiFunction;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
public interface DocumentPostProcessor extends BiFunction<Query, List<Document>, List<Document>> {
List<Document> process(Query query, List<Document> documents);
default List<Document> apply(Query query, List<Document> documents) {
return this.process(query, documents);
}
}
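As an illustration of the kind of logic a `DocumentPostProcessor` can hold, here is a toy re-ranker that orders documents by how many query words they share and keeps at most N. The types `PQuery`, `PDoc`, `SketchPostProcessor` and `PostProcessorDemo` are hypothetical; word overlap is an assumption standing in for a real rerank model and is not something Spring AI ships.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

// Hypothetical minimal types; DocumentPostProcessor has the same shape:
// BiFunction<Query, List<Document>, List<Document>> with process(...) as the abstract method.
record PQuery(String text) { }
record PDoc(String id, String text) { }

interface SketchPostProcessor extends BiFunction<PQuery, List<PDoc>, List<PDoc>> {
    List<PDoc> process(PQuery query, List<PDoc> documents);

    @Override
    default List<PDoc> apply(PQuery query, List<PDoc> documents) {
        return process(query, documents);
    }
}

final class PostProcessorDemo {
    // Lower-cased word set, split on non-word characters.
    private static Set<String> words(String text) {
        return Arrays.stream(text.toLowerCase().split("\\W+"))
                .filter(w -> !w.isEmpty())
                .collect(Collectors.toSet());
    }

    // Toy re-ranker: orders documents by the number of query words they contain (stable for ties),
    // then keeps at most maxDocs.
    static SketchPostProcessor overlapReranker(int maxDocs) {
        return (query, documents) -> {
            Set<String> queryWords = words(query.text());
            return documents.stream()
                    .sorted(Comparator.comparingLong((PDoc doc) -> words(doc.text()).stream()
                            .filter(queryWords::contains).count()).reversed())
                    .limit(maxDocs)
                    .toList();
        };
    }
}
```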
Generation
QueryAugmenter (query augmentation interface)
Combines the user query with additional contextual data so that the LLM receives richer background information
package org.springframework.ai.rag.generation.augmentation;
import java.util.List;
import java.util.function.BiFunction;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
public interface QueryAugmenter extends BiFunction<Query, List<Document>, Query> {
Query augment(Query query, List<Document> documents);
default Query apply(Query query, List<Document> documents) {
return this.augment(query, documents);
}
}
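Since every module in this chapter is a plain `java.util.function` type, the overall flow can be sketched as ordinary function chaining: transform the query, retrieve, post-process, augment. `PipelineSketch` is hypothetical and deliberately simplified (one query, one retriever, no expander or joiner), so it is not the actual `RetrievalAugmentationAdvisor`. As an assumption of this sketch, the transformed text is used only for retrieval, while the post-processor and augmenter receive the original user query so the final prompt keeps the user's wording.

```java
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.UnaryOperator;

// Hypothetical chaining of the modules above using only java.util.function types:
// QueryTransformer (Query -> Query), DocumentRetriever (Query -> List<Document>),
// DocumentPostProcessor ((Query, List<Document>) -> List<Document>), QueryAugmenter ((Query, List<Document>) -> Query).
// Assumption: single query, single retriever, strings instead of Query/Document.
final class PipelineSketch {
    static String run(String userQuery,
                      UnaryOperator<String> transformer,
                      Function<String, List<String>> retriever,
                      BiFunction<String, List<String>, List<String>> postProcessor,
                      BiFunction<String, List<String>, String> augmenter) {
        String transformed = transformer.apply(userQuery);        // pre-retrieval
        List<String> documents = retriever.apply(transformed);    // retrieval uses the transformed text
        List<String> processed = postProcessor.apply(userQuery, documents); // post-retrieval
        return augmenter.apply(userQuery, processed);             // generation input
    }
}
```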
ContextualQueryAugmenter
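Augments the user query by combining it with the content of the provided documents, producing an augmented query that gives the rest of the RAG flow richer background information
Fields
- PromptTemplate promptTemplate: custom prompt template used to build the augmented query (must contain the {query} and {context} placeholders)
- PromptTemplate emptyContextPromptTemplate: custom prompt template used when the context is empty
- boolean allowEmptyContext: whether an empty context is allowed (defaults to false; when true, the original query is returned unchanged if no documents were retrieved)
- Function<List<Document>, String> documentFormatter: custom function that formats the documents into the context string (by default, the document texts joined with line separators)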
package org.springframework.ai.rag.generation.augmentation;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
public final class ContextualQueryAugmenter implements QueryAugmenter {
private static final Logger logger = LoggerFactory.getLogger(ContextualQueryAugmenter.class);
private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("Context information is below.\n\n---------------------\n{context}\n---------------------\n\nGiven the context information and no prior knowledge, answer the query.\n\nFollow these rules:\n\n1. If the answer is not in the context, just say that you don't know.\n2. Avoid statements like \"Based on the context...\" or \"The provided information...\".\n\nQuery: {query}\n\nAnswer:\n");
private static final PromptTemplate DEFAULT_EMPTY_CONTEXT_PROMPT_TEMPLATE = new PromptTemplate("The user query is outside your knowledge base.\nPolitely inform the user that you can't answer it.\n");
private static final boolean DEFAULT_ALLOW_EMPTY_CONTEXT = false;
private static final Function<List<Document>, String> DEFAULT_DOCUMENT_FORMATTER = documents -> documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
private final PromptTemplate promptTemplate;
private final PromptTemplate emptyContextPromptTemplate;
private final boolean allowEmptyContext;
private final Function<List<Document>, String> documentFormatter;
public ContextualQueryAugmenter(@Nullable PromptTemplate promptTemplate, @Nullable PromptTemplate emptyContextPromptTemplate, @Nullable Boolean allowEmptyContext, @Nullable Function<List<Document>, String> documentFormatter) {
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
this.emptyContextPromptTemplate = emptyContextPromptTemplate != null ? emptyContextPromptTemplate : DEFAULT_EMPTY_CONTEXT_PROMPT_TEMPLATE;
this.allowEmptyContext = allowEmptyContext != null ? allowEmptyContext : DEFAULT_ALLOW_EMPTY_CONTEXT;
this.documentFormatter = documentFormatter != null ? documentFormatter : DEFAULT_DOCUMENT_FORMATTER;
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, new String[]{"query", "context"});
}
public Query augment(Query query, List<Document> documents) {
Assert.notNull(query, "query cannot be null");
Assert.notNull(documents, "documents cannot be null");
logger.debug("Augmenting query with contextual data");
if (documents.isEmpty()) {
return this.augmentQueryWhenEmptyContext(query);
} else {
String documentContext = this.documentFormatter.apply(documents);
Map<String, Object> promptParameters = Map.of("query", query.text(), "context", documentContext);
return new Query(this.promptTemplate.render(promptParameters));
}
}
private Query augmentQueryWhenEmptyContext(Query query) {
if (this.allowEmptyContext) {
logger.debug("Empty context is allowed. Returning the original query.");
return query;
} else {
logger.debug("Empty context is not allowed. Returning a specific query for empty context.");
return new Query(this.emptyContextPromptTemplate.render());
}
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private PromptTemplate promptTemplate;
private PromptTemplate emptyContextPromptTemplate;
private Boolean allowEmptyContext;
private Function<List<Document>, String> documentFormatter;
public Builder promptTemplate(PromptTemplate promptTemplate) {
this.promptTemplate = promptTemplate;
return this;
}
public Builder emptyContextPromptTemplate(PromptTemplate emptyContextPromptTemplate) {
this.emptyContextPromptTemplate = emptyContextPromptTemplate;
return this;
}
public Builder allowEmptyContext(Boolean allowEmptyContext) {
this.allowEmptyContext = allowEmptyContext;
return this;
}
public Builder documentFormatter(Function<List<Document>, String> documentFormatter) {
this.documentFormatter = documentFormatter;
return this;
}
public ContextualQueryAugmenter build() {
return new ContextualQueryAugmenter(this.promptTemplate, this.emptyContextPromptTemplate, this.allowEmptyContext, this.documentFormatter);
}
}
}
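The branching in `augment()` is easy to miss, so here is a minimal sketch of it: with documents, their texts are joined with the platform line separator and substituted into the template; without documents, the query is either returned unchanged (`allowEmptyContext = true`) or replaced by the "politely refuse" prompt (the default). `AugmenterSketch` is hypothetical; it assumes documents are plain strings, uses simple `{name}` replacement instead of `PromptTemplate`, and shortens the default templates.

```java
import java.util.List;

// Plain-Java sketch of ContextualQueryAugmenter.augment().
// Assumptions: documents are plain strings, "{name}" placeholders are replaced verbatim,
// and both templates are shortened versions of the defaults above.
final class AugmenterSketch {
    static final String PROMPT_TEMPLATE =
            "Context information is below.\n---------------------\n{context}\n---------------------\n"
            + "Given the context information and no prior knowledge, answer the query.\nQuery: {query}\nAnswer:\n";
    static final String EMPTY_CONTEXT_PROMPT =
            "The user query is outside your knowledge base.\nPolitely inform the user that you can't answer it.\n";

    static String augment(String query, List<String> documents, boolean allowEmptyContext) {
        if (documents.isEmpty()) {
            // Mirrors augmentQueryWhenEmptyContext(): pass through, or replace with the refusal prompt.
            return allowEmptyContext ? query : EMPTY_CONTEXT_PROMPT;
        }
        // Mirrors DEFAULT_DOCUMENT_FORMATTER: document texts joined by the platform line separator.
        String context = String.join(System.lineSeparator(), documents);
        return PROMPT_TEMPLATE.replace("{context}", context).replace("{query}", query);
    }
}
```

Because `allowEmptyContext` defaults to false, an empty retrieval result makes the model decline rather than answer from its own knowledge; building the augmenter with `allowEmptyContext(true)` lets the unchanged question reach the model instead.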
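Study group
Hi, I'm Yingzi (影子). I have worked at 🐻, at a new-energy company, and at Laotie (老铁), and I am now an AI R&D engineer. I have started a discussion group: one person walks fast, but a group walks far. I also maintain a set of Feishu cloud-doc notes with systematic interview material on backend and big-data topics; send me a private message to get them for free
Aside
Thanks to continued contributions, I have now been nominated as a Committer of the Spring AI Alibaba open-source community. Let's keep building AI engineering together~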