第四章:封装 RESTful API

13 阅读12分钟

第四章:封装 RESTful API

4.1 引言

在前三章中,我们已经完成了 RAG 系统的核心功能:第一章介绍了 RAG 的概念和开发环境,第二章实现了知识库构建和向量化,第三章整合了检索和生成逻辑,构建了一个完整的 RAG 流程。现在,我们需要将 RAG 系统封装为 RESTful API,以便将其功能暴露给外部应用,特别是 Spring Boot 项目。

本章的目标是帮助 Java 程序员实现以下内容:

  1. 使用 FastAPI 开发 RESTful API,封装 RAG 系统的查询、检索和生成功能。
  2. 设计清晰的 API 端点,支持多种查询场景。
  3. 实现 API 的输入验证、错误处理和性能优化。
  4. 测试 API 的功能和可靠性,为后续 Spring Boot 集成做好准备。

我们将通过详细的代码示例和说明,带你完成从 API 设计到部署的完整流程。本章假设你已经按照前三章的说明实现了 RAG 系统,并熟悉 Python 和 RESTful API 的基本概念。

4.2 为什么选择 FastAPI?

FastAPI 是一个现代、高性能的 Python Web 框架,特别适合开发 RESTful API。相比其他框架(如 Flask 或 Django REST Framework),FastAPI 有以下优势:

  • 高性能:基于异步 I/O(使用 Starlette 和 Uvicorn),支持高并发。
  • 自动文档生成:内置 OpenAPI 支持,生成交互式 Swagger UI 和 ReDoc 文档。
  • 类型检查:基于 Pydantic 的数据验证,减少手动校验代码。
  • 易于开发:简洁的语法,适合快速原型开发。
  • 与 AI 集成:与 LangChain、Hugging Face 等库无缝集成。

对于 RAG 系统,FastAPI 的异步特性和自动文档生成特别适合处理高并发的查询请求,并为 Java 开发者提供清晰的 API 接口文档。

4.3 配置 FastAPI 环境

4.3.1 安装 FastAPI 和依赖

确保已安装 FastAPI 和 Uvicorn(用于运行 API 服务):

pip install fastapi uvicorn python-multipart
  • fastapi:核心框架。
  • uvicorn:ASGI 服务器,用于运行 FastAPI 应用。
  • python-multipart:支持文件上传(如上传知识库文档)。

4.3.2 项目结构

为了保持代码组织清晰,我们将创建一个新的项目目录,整合 RAG 系统和 API 代码。推荐的目录结构如下:

rag_api/
├── app/
│   ├── __init__.py
│   ├── main.py              # FastAPI 主应用
│   ├── rag_system.py        # RAG 系统核心逻辑(基于第三章)
│   ├── models.py            # API 数据模型(Pydantic)
│   ├── utils.py             # 工具函数
├── rag_knowledge_base/      # 知识库目录(第二章)
├── rag_knowledge_base_index.faiss  # FAISS 索引文件
├── requirements.txt         # 依赖列表
├── README.md

创建 requirements.txt

fastapi
uvicorn
python-multipart
langchain
langchain-community
langchain-huggingface
sentence-transformers
faiss-cpu
transformers
torch
PyPDF2
markdown
nltk

安装所有依赖:

pip install -r requirements.txt

4.4 设计 RESTful API

4.4.1 API 端点规划

我们将设计以下 API 端点,覆盖 RAG 系统的核心功能:

  1. 查询端点/query):

    • 方法:POST
    • 功能:接收用户查询,返回 RAG 生成的回答和检索到的文档。
    • 输入:查询文本(JSON 格式)。
    • 输出:回答和相关文档。
  2. 检索端点/retrieve):

    • 方法:POST
    • 功能:仅执行检索,返回与查询相关的文档。
    • 输入:查询文本(JSON 格式)。
    • 输出:检索到的文档列表。
  3. 上传文档端点/upload):

    • 方法:POST
    • 功能:上传新文档到知识库,更新向量索引。
    • 输入:文件(multipart/form-data)。
    • 输出:上传状态。
  4. 健康检查端点/health):

    • 方法:GET
    • 功能:检查 API 服务是否正常运行。
    • 输出:状态信息。

4.4.2 API 数据模型

我们使用 Pydantic 定义 API 的输入和输出模型,确保数据验证和类型安全。

models.py

from pydantic import BaseModel
from typing import List, Optional

class QueryRequest(BaseModel):
    query: str
    top_k: Optional[int] = 3

class DocumentResponse(BaseModel):
    file: str
    chunk_id: str
    content: str
    distance: float

class QueryResponse(BaseModel):
    question: str
    answer: str
    retrieved_docs: List[DocumentResponse]

class RetrieveResponse(BaseModel):
    query: str
    documents: List[DocumentResponse]

class UploadResponse(BaseModel):
    status: str
    message: str
    file_name: Optional[str] = None

class HealthResponse(BaseModel):
    status: str

说明

  • QueryRequest:定义查询端点的输入,包含查询文本和可选的 top_k 参数。
  • QueryResponse:定义查询端点的输出,包含问题、回答和检索文档。
  • DocumentResponse:定义单个文档的结构,包含文件名、片段 ID、内容和距离。
  • RetrieveResponse:定义检索端点的输出,仅包含检索文档。
  • UploadResponse:定义上传端点的输出,包含状态和消息。
  • HealthResponse:定义健康检查端点的输出。

4.5 实现 RAG 系统(复用第三章)

为了避免重复,我们将第三章的 RAGSystem 类稍作修改,适配 API 使用。以下是更新后的 rag_system.py

rag_system.py

import faiss
import torch
import os
import PyPDF2
import markdown
import nltk
from nltk.tokenize import sent_tokenize
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from transformers import pipeline

nltk.download('punkt', quiet=True)

class RAGSystem:
    def __init__(self, index_path, knowledge_base_dir, embedding_model_name="sentence-transformers/all-MiniLM-L6-v2", llm_model_name="google/flan-t5-base"):
        self.knowledge_base_dir = knowledge_base_dir
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
        
        # 加载或初始化 FAISS 向量存储
        if os.path.exists(index_path):
            self.vector_store = FAISS.load_local(
                index_path,
                embeddings=self.embeddings,
                allow_dangerous_deserialization=True
            )
        else:
            self.vector_store = None
        
        # 初始化生成模型
        self.text_generation_pipeline = pipeline(
            "text2text-generation",
            model=llm_model_name,
            tokenizer=llm_model_name,
            max_length=512,
            device=0 if torch.cuda.is_available() else -1
        )
        self.llm = HuggingFacePipeline(pipeline=self.text_generation_pipeline)
        
        # 初始化提示模板
        self.prompt_template = """
        You are a knowledgeable assistant. Use the provided context to answer the question accurately and concisely. If the context does not contain enough information, say so.
        
        Context:
        {context}
        
        Question: {question}
        
        Answer:
        """
        self.prompt = PromptTemplate(
            input_variables=["context", "question"],
            template=self.prompt_template
        )
        
        # 创建 RAG 链
        if self.vector_store:
            self.rag_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",
                retriever=self.vector_store.as_retriever(search_kwargs={"k": 3}),
                chain_type_kwargs={"prompt": self.prompt}
            )
        else:
            self.rag_chain = None

    def query(self, question, top_k=3):
        if not self.rag_chain:
            return {"question": question, "answer": "Knowledge base not initialized.", "retrieved_docs": []}
        
        response = self.rag_chain.invoke({"query": question})
        retrieved_docs = [
            {
                "file": doc.metadata.get("file", ""),
                "chunk_id": doc.metadata.get("chunk_id", ""),
                "content": doc.page_content,
                "distance": doc.metadata.get("distance", 0.0)
            }
            for doc in response.get("source_documents", [])
        ]
        return {
            "question": question,
            "answer": response["result"],
            "retrieved_docs": retrieved_docs
        }

    def retrieve(self, query, top_k=3):
        if not self.vector_store:
            return {"query": query, "documents": []}
        
        docs = self.vector_store.similarity_search_with_score(query, k=top_k)
        return {
            "query": query,
            "documents": [
                {
                    "file": doc.metadata.get("file", ""),
                    "chunk_id": doc.metadata.get("chunk_id", ""),
                    "content": doc.page_content,
                    "distance": float(score)
                }
                for doc, score in docs
            ]
        }

    def read_txt_file(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()

    def read_pdf_file(self, file_path):
        text = ""
        with open(file_path, 'rb') as f:
            pdf = PyPDF2.PdfReader(f)
            for page in pdf.pages:
                text += page.extract_text() + "\n"
        return text

    def read_md_file(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            md_text = f.read()
            html = markdown.markdown(md_text)
            return html.replace('<p>', '').replace('</p>', '\n').strip()

    def split_text(self, text, max_length=500):
        sentences = sent_tokenize(text)
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) < max_length:
                current_chunk += sentence + " "
            else:
                chunks.append(current_chunk.strip())
                current_chunk = sentence + " "
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks

    def add_document(self, file_path, file_name):
        # 读取文件
        if file_name.endswith('.txt'):
            content = self.read_txt_file(file_path)
        elif file_name.endswith('.pdf'):
            content = self.read_pdf_file(file_path)
        elif file_name.endswith('.md'):
            content = self.read_md_file(file_path)
        else:
            return False, "Unsupported file format."

        # 分割文本
        chunks = self.split_text(content)
        documents = [
            {
                "page_content": chunk,
                "metadata": {"file": file_name, "chunk_id": f"{file_name}_{i}"}
            }
            for i, chunk in enumerate(chunks)
        ]

        # 更新向量存储
        if self.vector_store is None:
            self.vector_store = FAISS.from_texts(
                texts=[doc["page_content"] for doc in documents],
                embedding=self.embeddings,
                metadatas=[doc["metadata"] for doc in documents]
            )
        else:
            self.vector_store.add_texts(
                texts=[doc["page_content"] for doc in documents],
                metadatas=[doc["metadata"] for doc in documents]
            )

        # 更新 RAG 链
        self.rag_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vector_store.as_retriever(search_kwargs={"k": 3}),
            chain_type_kwargs={"prompt": self.prompt}
        )

        return True, "Document added successfully."

代码说明

  • RAGSystem 类整合了第三章的功能,并新增了 add_document 方法,用于处理上传的文档。
  • queryretrieve 方法返回格式化的结果,适配 API 输出。
  • 文件处理逻辑(read_txt_file 等)复用了第二章的代码。

4.6 实现 FastAPI 应用

现在,我们将实现 FastAPI 主应用,定义所有端点。

main.py

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from .models import QueryRequest, QueryResponse, RetrieveResponse, UploadResponse, HealthResponse, DocumentResponse
from .rag_system import RAGSystem
import os
import shutil

app = FastAPI(title="RAG API", description="RESTful API for Retrieval-Augmented Generation")

# 初始化 RAG 系统
rag_system = RAGSystem(
    index_path="rag_knowledge_base_index.faiss",
    knowledge_base_dir="rag_knowledge_base"
)

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """检查 API 服务状态"""
    return HealthResponse(status="healthy")

@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    """处理 RAG 查询"""
    try:
        result = rag_system.query(request.query, top_k=request.top_k)
        return QueryResponse(
            question=result["question"],
            answer=result["answer"],
            retrieved_docs=[
                DocumentResponse(**doc) for doc in result["retrieved_docs"]
            ]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query error: {str(e)}")

@app.post("/retrieve", response_model=RetrieveResponse)
async def retrieve(request: QueryRequest):
    """仅执行检索"""
    try:
        result = rag_system.retrieve(request.query, top_k=request.top_k)
        return RetrieveResponse(
            query=result["query"],
            documents=[
                DocumentResponse(**doc) for doc in result["documents"]
            ]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Retrieve error: {str(e)}")

@app.post("/upload", response_model=UploadResponse)
async def upload_file(file: UploadFile = File(...)):
    """上传文档到知识库"""
    try:
        # 检查文件格式
        if not file.filename.endswith(('.txt', '.pdf', '.md')):
            return UploadResponse(
                status="error",
                message="Unsupported file format. Only .txt, .pdf, .md are allowed."
            )
        
        # 保存文件
        file_path = os.path.join("rag_knowledge_base", file.filename)
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)
        
        # 添加到知识库
        success, message = rag_system.add_document(file_path, file.filename)
        if not success:
            return UploadResponse(status="error", message=message)
        
        # 保存更新后的 FAISS 索引
        rag_system.vector_store.save_local("rag_knowledge_base_index.faiss")
        
        return UploadResponse(
            status="success",
            message=message,
            file_name=file.filename
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Upload error: {str(e)}")

代码说明

  • FastAPI 应用:定义了 /health/query/retrieve/upload 四个端点。
  • 输入验证:使用 Pydantic 模型(QueryRequest 等)自动验证输入。
  • 错误处理:使用 HTTPException 处理异常,返回清晰的错误信息。
  • 文件上传:支持上传 .txt.pdf.md 文件,自动更新知识库和 FAISS 索引。

4.7 运行和测试 API

4.7.1 运行 FastAPI 服务

rag_api 目录下运行以下命令启动 API:

uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
  • --host 0.0.0.0:允许外部访问。
  • --port 8000:指定端口。
  • --reload:开发模式,代码更改后自动重启。

访问 http://localhost:8000/docs,你将看到自动生成的 Swagger UI,包含所有端点的交互式文档。

4.7.2 测试 API

我们将使用 Python 的 requests 库测试 API 端点。

安装 requests

pip install requests

测试脚本test_api.py):

import requests
import json

base_url = "http://localhost:8000"

def test_health():
    response = requests.get(f"{base_url}/health")
    print("Health Check:", response.json())

def test_query():
    payload = {
        "query": "Which operating systems are supported by the product?",
        "top_k": 3
    }
    response = requests.post(f"{base_url}/query", json=payload)
    print("Query Response:", json.dumps(response.json(), indent=2, ensure_ascii=False))

def test_retrieve():
    payload = {
        "query": "How to contact technical support?",
        "top_k": 2
    }
    response = requests.post(f"{base_url}/retrieve", json=payload)
    print("Retrieve Response:", json.dumps(response.json(), indent=2, ensure_ascii=False))

def test_upload():
    file_path = "test_doc.txt"
    with open(file_path, "w", encoding="utf-8") as f:
        f.write("Test document content.")
    
    with open(file_path, "rb") as f:
        files = {"file": (file_path, f, "text/plain")}
        response = requests.post(f"{base_url}/upload", files=files)
    print("Upload Response:", json.dumps(response.json(), indent=2, ensure_ascii=False))

if __name__ == "__main__":
    print("Testing API...")
    test_health()
    test_query()
    test_retrieve()
    test_upload()

示例输出

Testing API...
Health Check: {'status': 'healthy'}
Query Response: {
  "question": "Which operating systems are supported by the product?",
  "answer": "Our product supports Windows 10/11, macOS 12+, and Ubuntu 20.04+.",
  "retrieved_docs": [
    {
      "file": "faq.txt",
      "chunk_id": "faq.txt_0",
      "content": "Q: 产品支持哪些操作系统? A: 我们的产品支持 Windows 10/11、macOS 12+ 和 Ubuntu 20.04+。",
      "distance": 0.1234
    },
    ...
  ]
}
Retrieve Response: {
  "query": "How to contact technical support?",
  "documents": [
    {
      "file": "faq.txt",
      "chunk_id": "faq.txt_1",
      "content": "Q: 如何联系技术支持? A: 您可以通过邮箱 support@company.com 或电话 123-456-7890 联系我们。",
      "distance": 0.0987
    },
    ...
  ]
}
Upload Response: {
  "status": "success",
  "message": "Document added successfully.",
  "file_name": "test_doc.txt"
}

分析

  • 健康检查:确认 API 服务正常运行。
  • 查询端点:返回 RAG 生成的回答和检索文档。
  • 检索端点:仅返回相关文档,适合需要单独检索的场景。
  • 上传端点:成功上传新文档并更新知识库。

4.8 优化 API 性能

4.8.1 异步处理

FastAPI 支持异步函数,适合处理 I/O 密集型任务(如检索和生成)。我们已经在 main.py 中使用了 async 函数,但可以进一步优化:

  • 批量查询:支持一次提交多个查询:

    class BatchQueryRequest(BaseModel):
        queries: List[QueryRequest]
    
    @app.post("/batch_query", response_model=List[QueryResponse])
    async def batch_query(request: BatchQueryRequest):
        try:
            results = [rag_system.query(q.query, top_k=q.top_k) for q in request.queries]
            return [
                QueryResponse(
                    question=r["question"],
                    answer=r["answer"],
                    retrieved_docs=[DocumentResponse(**doc) for doc in r["retrieved_docs"]]
                )
                for r in results
            ]
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Batch query error: {str(e)}")
    
  • 异步生成:如果生成模型支持异步推理,可以使用异步管道(需要自定义 Hugging Face 集成)。

4.8.2 缓存

对于频繁重复的查询,可以使用缓存减少计算开销。我们使用 fastapi-cache 实现内存缓存。

安装依赖

pip install fastapi-cache2[redis]

配置缓存(更新 main.py):

from fastapi import FastAPI
from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache

app = FastAPI(title="RAG API", description="RESTful API for Retrieval-Augmented Generation")

@app.on_event("startup")
async def startup():
    FastAPICache.init(InMemoryBackend(), prefix="fastapi-cache")

# ... 其他代码保持不变 ...

@app.post("/query", response_model=QueryResponse)
@cache(expire=3600)  # 缓存 1 小时
async def query(request: QueryRequest):
    # 保持原逻辑
    try:
        result = rag_system.query(request.query, top_k=request.top_k)
        return QueryResponse(
            question=result["question"],
            answer=result["answer"],
            retrieved_docs=[
                DocumentResponse(**doc) for doc in result["retrieved_docs"]
            ]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query error: {str(e)}")

说明

  • @cache(expire=3600):缓存查询结果 1 小时。
  • InMemoryBackend:使用内存缓存,生产环境可以替换为 Redis。

4.8.3 错误处理和日志

为了提高 API 的健壮性,我们添加详细的日志记录:

utils.py

import logging

def setup_logger():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler("rag_api.log"),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger("RAG_API")

logger = setup_logger()

更新 main.py

from .utils import logger

@app.post("/query", response_model=QueryResponse)
@cache(expire=3600)
async def query(request: QueryRequest):
    logger.info(f"Received query: {request.query}, top_k: {request.top_k}")
    try:
        result = rag_system.query(request.query, top_k=request.top_k)
        logger.info(f"Query successful: {request.query}")
        return QueryResponse(
            question=result["question"],
            answer=result["answer"],
            retrieved_docs=[
                DocumentResponse(**doc) for doc in result["retrieved_docs"]
            ]
        )
    except Exception as e:
        logger.error(f"Query error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Query error: {str(e)}")

说明

  • 日志记录查询请求和错误,便于调试和监控。
  • 日志同时输出到文件(rag_api.log)和控制台。

4.9 常见问题与解决方案

在开发和测试 API 时,Java 程序员可能遇到以下问题:

  1. 文件上传失败

    • 问题:上传大型 PDF 文件时超时或失败。

    • 解决方案

      • 增加 FastAPI 的超时配置:

        uvicorn app.main:app --host 0.0.0.0 --port 8000 --timeout-keep-alive 300
        
      • 使用异步文件处理:

        async def upload_file(file: UploadFile = File(...)):
            file_path = os.path.join("rag_knowledge_base", file.filename)
            async with aiofiles.open(file_path, "wb") as f:
                content = await file.read()
                await f.write(content)
            # 其余逻辑同上
        
  2. API 响应慢

    • 问题:查询或检索端点响应时间长。

    • 解决方案

      • 优化 FAISS 索引(使用 IndexIVFFlat)。
      • 使用 GPU 加速生成模型。
      • 启用缓存(参考 4.8.2)。
  3. 输入验证错误

    • 问题:客户端发送无效的 JSON 数据导致 422 错误。

    • 解决方案

      • 检查 Swagger UI 中的输入格式。

      • 提供详细的错误消息:

        from pydantic import ValidationError
        
        @app.post("/query", response_model=QueryResponse)
        async def query(request: QueryRequest):
            try:
                result = rag_system.query(request.query, top_k=request.top_k)
                return QueryResponse(
                    question=result["question"],
                    answer=result["answer"],
                    retrieved_docs=[DocumentResponse(**doc) for doc in result["retrieved_docs"]]
                )
            except ValidationError as e:
                logger.error(f"Validation error: {str(e)}")
                raise HTTPException(status_code=422, detail=f"Invalid input: {str(e)}")
            except Exception as e:
                logger.error(f"Query error: {str(e)}")
                raise HTTPException(status_code=500, detail=f"Query error: {str(e)}")
        
  4. 知识库更新不生效

    • 问题:上传新文档后,查询结果未反映更新。

    • 解决方案

      • 确保 vector_store.save_local 正确保存索引。
      • 检查 add_document 方法是否正确更新 rag_chain

4.10 本章总结

本章详细讲解了如何使用 FastAPI 将 RAG 系统封装为 RESTful API。你现在应该掌握了以下内容:

  • 使用 FastAPI 开发 RESTful API,定义查询、检索、上传和健康检查端点。
  • 使用 Pydantic 模型进行输入验证和输出格式化。
  • 实现文件上传功能,动态更新知识库和 FAISS 索引。
  • 优化 API 性能,包括异步处理、缓存和日志记录。
  • 测试 API 的功能和可靠性,解决常见问题。

在下一章,我们将实现 Spring Boot 项目,调用 RAG API,实现前后端交互,并讨论部署方案。