第四章:封装 RESTful API
4.1 引言
在前三章中,我们已经完成了 RAG 系统的核心功能:第一章介绍了 RAG 的概念和开发环境,第二章实现了知识库构建和向量化,第三章整合了检索和生成逻辑,构建了一个完整的 RAG 流程。现在,我们需要将 RAG 系统封装为 RESTful API,以便将其功能暴露给外部应用,特别是 Spring Boot 项目。
本章的目标是帮助 Java 程序员实现以下内容:
- 使用 FastAPI 开发 RESTful API,封装 RAG 系统的查询、检索和生成功能。
- 设计清晰的 API 端点,支持多种查询场景。
- 实现 API 的输入验证、错误处理和性能优化。
- 测试 API 的功能和可靠性,为后续 Spring Boot 集成做好准备。
我们将通过详细的代码示例和说明,带你完成从 API 设计到部署的完整流程。本章假设你已经按照前三章的说明实现了 RAG 系统,并熟悉 Python 和 RESTful API 的基本概念。
4.2 为什么选择 FastAPI?
FastAPI 是一个现代、高性能的 Python Web 框架,特别适合开发 RESTful API。相比其他框架(如 Flask 或 Django REST Framework),FastAPI 有以下优势:
- 高性能:基于异步 I/O(使用 Starlette 和 Uvicorn),支持高并发。
- 自动文档生成:内置 OpenAPI 支持,生成交互式 Swagger UI 和 ReDoc 文档。
- 类型检查:基于 Pydantic 的数据验证,减少手动校验代码。
- 易于开发:简洁的语法,适合快速原型开发。
- 与 AI 集成:与 LangChain、Hugging Face 等库无缝集成。
对于 RAG 系统,FastAPI 的异步特性和自动文档生成特别适合处理高并发的查询请求,并为 Java 开发者提供清晰的 API 接口文档。
4.3 配置 FastAPI 环境
4.3.1 安装 FastAPI 和依赖
确保已安装 FastAPI 和 Uvicorn(用于运行 API 服务):
pip install fastapi uvicorn python-multipart
fastapi
:核心框架。uvicorn
:ASGI 服务器,用于运行 FastAPI 应用。python-multipart
:支持文件上传(如上传知识库文档)。
4.3.2 项目结构
为了保持代码组织清晰,我们将创建一个新的项目目录,整合 RAG 系统和 API 代码。推荐的目录结构如下:
rag_api/
├── app/
│ ├── __init__.py
│ ├── main.py # FastAPI 主应用
│ ├── rag_system.py # RAG 系统核心逻辑(基于第三章)
│ ├── models.py # API 数据模型(Pydantic)
│ ├── utils.py # 工具函数
├── rag_knowledge_base/ # 知识库目录(第二章)
├── rag_knowledge_base_index.faiss # FAISS 索引文件
├── requirements.txt # 依赖列表
├── README.md
创建 requirements.txt:
fastapi
uvicorn
python-multipart
langchain
langchain-community
langchain-huggingface
sentence-transformers
faiss-cpu
transformers
torch
PyPDF2
markdown
nltk
安装所有依赖:
pip install -r requirements.txt
4.4 设计 RESTful API
4.4.1 API 端点规划
我们将设计以下 API 端点,覆盖 RAG 系统的核心功能:
-
查询端点(
/query
):- 方法:POST
- 功能:接收用户查询,返回 RAG 生成的回答和检索到的文档。
- 输入:查询文本(JSON 格式)。
- 输出:回答和相关文档。
-
检索端点(
/retrieve
):- 方法:POST
- 功能:仅执行检索,返回与查询相关的文档。
- 输入:查询文本(JSON 格式)。
- 输出:检索到的文档列表。
-
上传文档端点(
/upload
):- 方法:POST
- 功能:上传新文档到知识库,更新向量索引。
- 输入:文件(multipart/form-data)。
- 输出:上传状态。
-
健康检查端点(
/health
):- 方法:GET
- 功能:检查 API 服务是否正常运行。
- 输出:状态信息。
4.4.2 API 数据模型
我们使用 Pydantic 定义 API 的输入和输出模型,确保数据验证和类型安全。
models.py:
from pydantic import BaseModel
from typing import List, Optional
class QueryRequest(BaseModel):
query: str
top_k: Optional[int] = 3
class DocumentResponse(BaseModel):
file: str
chunk_id: str
content: str
distance: float
class QueryResponse(BaseModel):
question: str
answer: str
retrieved_docs: List[DocumentResponse]
class RetrieveResponse(BaseModel):
query: str
documents: List[DocumentResponse]
class UploadResponse(BaseModel):
status: str
message: str
file_name: Optional[str] = None
class HealthResponse(BaseModel):
status: str
说明:
QueryRequest
:定义查询端点的输入,包含查询文本和可选的 top_k 参数。QueryResponse
:定义查询端点的输出,包含问题、回答和检索文档。DocumentResponse
:定义单个文档的结构,包含文件名、片段 ID、内容和距离。RetrieveResponse
:定义检索端点的输出,仅包含检索文档。UploadResponse
:定义上传端点的输出,包含状态和消息。HealthResponse
:定义健康检查端点的输出。
4.5 实现 RAG 系统(复用第三章)
为了避免重复,我们将第三章的 RAGSystem
类稍作修改,适配 API 使用。以下是更新后的 rag_system.py
:
rag_system.py:
import faiss
import torch
import os
import PyPDF2
import markdown
import nltk
from nltk.tokenize import sent_tokenize
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from transformers import pipeline
nltk.download('punkt', quiet=True)
class RAGSystem:
def __init__(self, index_path, knowledge_base_dir, embedding_model_name="sentence-transformers/all-MiniLM-L6-v2", llm_model_name="google/flan-t5-base"):
self.knowledge_base_dir = knowledge_base_dir
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
# 加载或初始化 FAISS 向量存储
if os.path.exists(index_path):
self.vector_store = FAISS.load_local(
index_path,
embeddings=self.embeddings,
allow_dangerous_deserialization=True
)
else:
self.vector_store = None
# 初始化生成模型
self.text_generation_pipeline = pipeline(
"text2text-generation",
model=llm_model_name,
tokenizer=llm_model_name,
max_length=512,
device=0 if torch.cuda.is_available() else -1
)
self.llm = HuggingFacePipeline(pipeline=self.text_generation_pipeline)
# 初始化提示模板
self.prompt_template = """
You are a knowledgeable assistant. Use the provided context to answer the question accurately and concisely. If the context does not contain enough information, say so.
Context:
{context}
Question: {question}
Answer:
"""
self.prompt = PromptTemplate(
input_variables=["context", "question"],
template=self.prompt_template
)
# 创建 RAG 链
if self.vector_store:
self.rag_chain = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff",
retriever=self.vector_store.as_retriever(search_kwargs={"k": 3}),
chain_type_kwargs={"prompt": self.prompt}
)
else:
self.rag_chain = None
def query(self, question, top_k=3):
if not self.rag_chain:
return {"question": question, "answer": "Knowledge base not initialized.", "retrieved_docs": []}
response = self.rag_chain.invoke({"query": question})
retrieved_docs = [
{
"file": doc.metadata.get("file", ""),
"chunk_id": doc.metadata.get("chunk_id", ""),
"content": doc.page_content,
"distance": doc.metadata.get("distance", 0.0)
}
for doc in response.get("source_documents", [])
]
return {
"question": question,
"answer": response["result"],
"retrieved_docs": retrieved_docs
}
def retrieve(self, query, top_k=3):
if not self.vector_store:
return {"query": query, "documents": []}
docs = self.vector_store.similarity_search_with_score(query, k=top_k)
return {
"query": query,
"documents": [
{
"file": doc.metadata.get("file", ""),
"chunk_id": doc.metadata.get("chunk_id", ""),
"content": doc.page_content,
"distance": float(score)
}
for doc, score in docs
]
}
def read_txt_file(self, file_path):
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
def read_pdf_file(self, file_path):
text = ""
with open(file_path, 'rb') as f:
pdf = PyPDF2.PdfReader(f)
for page in pdf.pages:
text += page.extract_text() + "\n"
return text
def read_md_file(self, file_path):
with open(file_path, 'r', encoding='utf-8') as f:
md_text = f.read()
html = markdown.markdown(md_text)
return html.replace('<p>', '').replace('</p>', '\n').strip()
def split_text(self, text, max_length=500):
sentences = sent_tokenize(text)
chunks = []
current_chunk = ""
for sentence in sentences:
if len(current_chunk) + len(sentence) < max_length:
current_chunk += sentence + " "
else:
chunks.append(current_chunk.strip())
current_chunk = sentence + " "
if current_chunk:
chunks.append(current_chunk.strip())
return chunks
def add_document(self, file_path, file_name):
# 读取文件
if file_name.endswith('.txt'):
content = self.read_txt_file(file_path)
elif file_name.endswith('.pdf'):
content = self.read_pdf_file(file_path)
elif file_name.endswith('.md'):
content = self.read_md_file(file_path)
else:
return False, "Unsupported file format."
# 分割文本
chunks = self.split_text(content)
documents = [
{
"page_content": chunk,
"metadata": {"file": file_name, "chunk_id": f"{file_name}_{i}"}
}
for i, chunk in enumerate(chunks)
]
# 更新向量存储
if self.vector_store is None:
self.vector_store = FAISS.from_texts(
texts=[doc["page_content"] for doc in documents],
embedding=self.embeddings,
metadatas=[doc["metadata"] for doc in documents]
)
else:
self.vector_store.add_texts(
texts=[doc["page_content"] for doc in documents],
metadatas=[doc["metadata"] for doc in documents]
)
# 更新 RAG 链
self.rag_chain = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff",
retriever=self.vector_store.as_retriever(search_kwargs={"k": 3}),
chain_type_kwargs={"prompt": self.prompt}
)
return True, "Document added successfully."
代码说明:
RAGSystem
类整合了第三章的功能,并新增了add_document
方法,用于处理上传的文档。query
和retrieve
方法返回格式化的结果,适配 API 输出。- 文件处理逻辑(
read_txt_file
等)复用了第二章的代码。
4.6 实现 FastAPI 应用
现在,我们将实现 FastAPI 主应用,定义所有端点。
main.py:
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from .models import QueryRequest, QueryResponse, RetrieveResponse, UploadResponse, HealthResponse, DocumentResponse
from .rag_system import RAGSystem
import os
import shutil
app = FastAPI(title="RAG API", description="RESTful API for Retrieval-Augmented Generation")
# 初始化 RAG 系统
rag_system = RAGSystem(
index_path="rag_knowledge_base_index.faiss",
knowledge_base_dir="rag_knowledge_base"
)
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""检查 API 服务状态"""
return HealthResponse(status="healthy")
@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
"""处理 RAG 查询"""
try:
result = rag_system.query(request.query, top_k=request.top_k)
return QueryResponse(
question=result["question"],
answer=result["answer"],
retrieved_docs=[
DocumentResponse(**doc) for doc in result["retrieved_docs"]
]
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Query error: {str(e)}")
@app.post("/retrieve", response_model=RetrieveResponse)
async def retrieve(request: QueryRequest):
"""仅执行检索"""
try:
result = rag_system.retrieve(request.query, top_k=request.top_k)
return RetrieveResponse(
query=result["query"],
documents=[
DocumentResponse(**doc) for doc in result["documents"]
]
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Retrieve error: {str(e)}")
@app.post("/upload", response_model=UploadResponse)
async def upload_file(file: UploadFile = File(...)):
"""上传文档到知识库"""
try:
# 检查文件格式
if not file.filename.endswith(('.txt', '.pdf', '.md')):
return UploadResponse(
status="error",
message="Unsupported file format. Only .txt, .pdf, .md are allowed."
)
# 保存文件
file_path = os.path.join("rag_knowledge_base", file.filename)
with open(file_path, "wb") as f:
shutil.copyfileobj(file.file, f)
# 添加到知识库
success, message = rag_system.add_document(file_path, file.filename)
if not success:
return UploadResponse(status="error", message=message)
# 保存更新后的 FAISS 索引
rag_system.vector_store.save_local("rag_knowledge_base_index.faiss")
return UploadResponse(
status="success",
message=message,
file_name=file.filename
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Upload error: {str(e)}")
代码说明:
- FastAPI 应用:定义了
/health
、/query
、/retrieve
和/upload
四个端点。 - 输入验证:使用 Pydantic 模型(
QueryRequest
等)自动验证输入。 - 错误处理:使用
HTTPException
处理异常,返回清晰的错误信息。 - 文件上传:支持上传
.txt
、.pdf
和.md
文件,自动更新知识库和 FAISS 索引。
4.7 运行和测试 API
4.7.1 运行 FastAPI 服务
在 rag_api
目录下运行以下命令启动 API:
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
--host 0.0.0.0
:允许外部访问。--port 8000
:指定端口。--reload
:开发模式,代码更改后自动重启。
访问 http://localhost:8000/docs
,你将看到自动生成的 Swagger UI,包含所有端点的交互式文档。
4.7.2 测试 API
我们将使用 Python 的 requests
库测试 API 端点。
安装 requests:
pip install requests
测试脚本(test_api.py
):
import requests
import json
base_url = "http://localhost:8000"
def test_health():
response = requests.get(f"{base_url}/health")
print("Health Check:", response.json())
def test_query():
payload = {
"query": "Which operating systems are supported by the product?",
"top_k": 3
}
response = requests.post(f"{base_url}/query", json=payload)
print("Query Response:", json.dumps(response.json(), indent=2, ensure_ascii=False))
def test_retrieve():
payload = {
"query": "How to contact technical support?",
"top_k": 2
}
response = requests.post(f"{base_url}/retrieve", json=payload)
print("Retrieve Response:", json.dumps(response.json(), indent=2, ensure_ascii=False))
def test_upload():
file_path = "test_doc.txt"
with open(file_path, "w", encoding="utf-8") as f:
f.write("Test document content.")
with open(file_path, "rb") as f:
files = {"file": (file_path, f, "text/plain")}
response = requests.post(f"{base_url}/upload", files=files)
print("Upload Response:", json.dumps(response.json(), indent=2, ensure_ascii=False))
if __name__ == "__main__":
print("Testing API...")
test_health()
test_query()
test_retrieve()
test_upload()
示例输出:
Testing API...
Health Check: {'status': 'healthy'}
Query Response: {
"question": "Which operating systems are supported by the product?",
"answer": "Our product supports Windows 10/11, macOS 12+, and Ubuntu 20.04+.",
"retrieved_docs": [
{
"file": "faq.txt",
"chunk_id": "faq.txt_0",
"content": "Q: 产品支持哪些操作系统? A: 我们的产品支持 Windows 10/11、macOS 12+ 和 Ubuntu 20.04+。",
"distance": 0.1234
},
...
]
}
Retrieve Response: {
"query": "How to contact technical support?",
"documents": [
{
"file": "faq.txt",
"chunk_id": "faq.txt_1",
"content": "Q: 如何联系技术支持? A: 您可以通过邮箱 support@company.com 或电话 123-456-7890 联系我们。",
"distance": 0.0987
},
...
]
}
Upload Response: {
"status": "success",
"message": "Document added successfully.",
"file_name": "test_doc.txt"
}
分析:
- 健康检查:确认 API 服务正常运行。
- 查询端点:返回 RAG 生成的回答和检索文档。
- 检索端点:仅返回相关文档,适合需要单独检索的场景。
- 上传端点:成功上传新文档并更新知识库。
4.8 优化 API 性能
4.8.1 异步处理
FastAPI 支持异步函数,适合处理 I/O 密集型任务(如检索和生成)。我们已经在 main.py
中使用了 async
函数,但可以进一步优化:
-
批量查询:支持一次提交多个查询:
class BatchQueryRequest(BaseModel): queries: List[QueryRequest] @app.post("/batch_query", response_model=List[QueryResponse]) async def batch_query(request: BatchQueryRequest): try: results = [rag_system.query(q.query, top_k=q.top_k) for q in request.queries] return [ QueryResponse( question=r["question"], answer=r["answer"], retrieved_docs=[DocumentResponse(**doc) for doc in r["retrieved_docs"]] ) for r in results ] except Exception as e: raise HTTPException(status_code=500, detail=f"Batch query error: {str(e)}")
-
异步生成:如果生成模型支持异步推理,可以使用异步管道(需要自定义 Hugging Face 集成)。
4.8.2 缓存
对于频繁重复的查询,可以使用缓存减少计算开销。我们使用 fastapi-cache
实现内存缓存。
安装依赖:
pip install fastapi-cache2[redis]
配置缓存(更新 main.py
):
from fastapi import FastAPI
from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache
app = FastAPI(title="RAG API", description="RESTful API for Retrieval-Augmented Generation")
@app.on_event("startup")
async def startup():
FastAPICache.init(InMemoryBackend(), prefix="fastapi-cache")
# ... 其他代码保持不变 ...
@app.post("/query", response_model=QueryResponse)
@cache(expire=3600) # 缓存 1 小时
async def query(request: QueryRequest):
# 保持原逻辑
try:
result = rag_system.query(request.query, top_k=request.top_k)
return QueryResponse(
question=result["question"],
answer=result["answer"],
retrieved_docs=[
DocumentResponse(**doc) for doc in result["retrieved_docs"]
]
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Query error: {str(e)}")
说明:
@cache(expire=3600)
:缓存查询结果 1 小时。InMemoryBackend
:使用内存缓存,生产环境可以替换为 Redis。
4.8.3 错误处理和日志
为了提高 API 的健壮性,我们添加详细的日志记录:
utils.py:
import logging
def setup_logger():
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler("rag_api.log"),
logging.StreamHandler()
]
)
return logging.getLogger("RAG_API")
logger = setup_logger()
更新 main.py:
from .utils import logger
@app.post("/query", response_model=QueryResponse)
@cache(expire=3600)
async def query(request: QueryRequest):
logger.info(f"Received query: {request.query}, top_k: {request.top_k}")
try:
result = rag_system.query(request.query, top_k=request.top_k)
logger.info(f"Query successful: {request.query}")
return QueryResponse(
question=result["question"],
answer=result["answer"],
retrieved_docs=[
DocumentResponse(**doc) for doc in result["retrieved_docs"]
]
)
except Exception as e:
logger.error(f"Query error: {str(e)}")
raise HTTPException(status_code=500, detail=f"Query error: {str(e)}")
说明:
- 日志记录查询请求和错误,便于调试和监控。
- 日志同时输出到文件(
rag_api.log
)和控制台。
4.9 常见问题与解决方案
在开发和测试 API 时,Java 程序员可能遇到以下问题:
-
文件上传失败:
-
问题:上传大型 PDF 文件时超时或失败。
-
解决方案:
-
增加 FastAPI 的超时配置:
uvicorn app.main:app --host 0.0.0.0 --port 8000 --timeout-keep-alive 300
-
使用异步文件处理:
async def upload_file(file: UploadFile = File(...)): file_path = os.path.join("rag_knowledge_base", file.filename) async with aiofiles.open(file_path, "wb") as f: content = await file.read() await f.write(content) # 其余逻辑同上
-
-
-
API 响应慢:
-
问题:查询或检索端点响应时间长。
-
解决方案:
- 优化 FAISS 索引(使用
IndexIVFFlat
)。 - 使用 GPU 加速生成模型。
- 启用缓存(参考 4.8.2)。
- 优化 FAISS 索引(使用
-
-
输入验证错误:
-
问题:客户端发送无效的 JSON 数据导致 422 错误。
-
解决方案:
-
检查 Swagger UI 中的输入格式。
-
提供详细的错误消息:
from pydantic import ValidationError @app.post("/query", response_model=QueryResponse) async def query(request: QueryRequest): try: result = rag_system.query(request.query, top_k=request.top_k) return QueryResponse( question=result["question"], answer=result["answer"], retrieved_docs=[DocumentResponse(**doc) for doc in result["retrieved_docs"]] ) except ValidationError as e: logger.error(f"Validation error: {str(e)}") raise HTTPException(status_code=422, detail=f"Invalid input: {str(e)}") except Exception as e: logger.error(f"Query error: {str(e)}") raise HTTPException(status_code=500, detail=f"Query error: {str(e)}")
-
-
-
知识库更新不生效:
-
问题:上传新文档后,查询结果未反映更新。
-
解决方案:
- 确保
vector_store.save_local
正确保存索引。 - 检查
add_document
方法是否正确更新rag_chain
。
- 确保
-
4.10 本章总结
本章详细讲解了如何使用 FastAPI 将 RAG 系统封装为 RESTful API。你现在应该掌握了以下内容:
- 使用 FastAPI 开发 RESTful API,定义查询、检索、上传和健康检查端点。
- 使用 Pydantic 模型进行输入验证和输出格式化。
- 实现文件上传功能,动态更新知识库和 FAISS 索引。
- 优化 API 性能,包括异步处理、缓存和日志记录。
- 测试 API 的功能和可靠性,解决常见问题。
在下一章,我们将实现 Spring Boot 项目,调用 RAG API,实现前后端交互,并讨论部署方案。