10.使用LangServe+FastAPI服务化!
本项目中我们使用的是LangServe+fastAPI实现的服务化,首先要自行安装依赖。
然后在项目根目录之下,新建server.py
编写代码如下:
from langserve import add_routes
from fastapi import FastAPI, Request, HTTPException
from agent import ai_client
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
import os
import re
import uuid
import config
from tools.ai_tools import create_pdf, write_to_file
# FastAPI接口定义
app = FastAPI(
title="Java助手 - LangChain Server",
version="1.0",
description="支持多轮对话和检索增强的Java问答服务",
)
# 使用slowapi进行基于客户端ip的限流配置来创建一个限流器对象
limiter = Limiter(
key_func=get_remote_address, # 使用客户端 IP 作为限流 key
default_limits=["5/minute"], # 默认限流:每分钟最多 5 次
headers_enabled=True, # 返回 X-RateLimit-* 头
storage_uri="redis://localhost:6379" # 使用内存存储(生产建议用 redis://localhost:6379)
)
app.state.limiter = limiter
# FastAPI添加异常处理器
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # 使用 slowapi 内置 handler
# FastAPI添加中间件(限流)
app.add_middleware(SlowAPIMiddleware)
# 跨域配置
app.add_middleware(
# 内置跨域中间件
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 健康检查接口
@app.get("/healthz")
async def healthz():
return {"status": "ok"}
# 请求模型定义
class ChatRequest(BaseModel):
session_id: str
question: str
# 导出pdf请求模型定义
class ExportPdfRequest(BaseModel):
title: str
content: str
filename: str = None
# 导出文件请求模型定义
class ExportFileRequest(BaseModel):
filename: str
content: str
# 安全检查
safe_name_pattern = re.compile(r'^[\w-.]+$')
os.makedirs(config.EXPORT_DIR, exist_ok=True)
# 聊天接口(标准响应)
@app.post("/chat")
# 限流每分钟10次
@limiter.limit("10/minute")
# 异步支持,(返回的是一个python协程对象)
async def chat_endpoint(request: Request, body: ChatRequest):
try:
invoke_config = {"configurable": {"session_id": body.session_id}}
result = await ai_client.CHAIN_WITH_HISTORY.ainvoke(
{"input": body.question},
config=invoke_config
)
# 如果ai返回是一个字典的话
if isinstance(result, dict):
answer = result.get("output") or result.get("response") or str(result)
# 不是字典的话直接转化成字符串
else:
answer = str(result)
return {"response": answer}
except Exception as e:
return {"error": str(e)}
# 流式聊天接口
@app.post("/chat/stream")
# 每分钟15次
@limiter.limit("15/minute")
async def chat_stream(request: Request, body: ChatRequest):
invoke_config = {"configurable": {"session_id": body.session_id}}
async def event_generator():
try:
async for chunk in ai_client.CHAIN_WITH_HISTORY.astream(
{"input": body.question},
config=invoke_config,
stream_mode="values"
):
# 根据 chain 返回结构处理 chunk
text = chunk if isinstance(chunk, str) else (
chunk.get("output") if isinstance(chunk, dict) else str(chunk)
)
if text.strip():
yield f"data: {text}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
yield f"data: [ERROR] {str(e)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
# 导出 PDF
@app.post("/export/pdf")
@limiter.limit("10/minute")
async def export_pdf(request: Request, body: ExportPdfRequest):
filename = body.filename or f"export-{uuid.uuid4().hex}.pdf"
if not safe_name_pattern.match(filename):
raise HTTPException(status_code=400, detail="非法文件名")
path = os.path.abspath(os.path.join(config.EXPORT_DIR, filename))
export_dir_abs = os.path.abspath(config.EXPORT_DIR)
if not path.startswith(export_dir_abs):
raise HTTPException(status_code=400, detail="非法路径")
result = create_pdf.invoke({
"content": body.content,
"title": body.title,
"filename": path
})
if not result or not result.get("success"):
error_msg = result.get("error", "生成失败") if result else "生成失败"
raise HTTPException(status_code=500, detail=error_msg)
return {"success": True, "filename": os.path.basename(path)}
# 导出普通文件
@app.post("/export/file")
@limiter.limit("20/minute")
async def export_file(request: Request, body: ExportFileRequest):
if not safe_name_pattern.match(body.filename):
raise HTTPException(status_code=400, detail="非法文件名")
path = os.path.abspath(os.path.join(config.EXPORT_DIR, body.filename))
export_dir_abs = os.path.abspath(config.EXPORT_DIR)
if not path.startswith(export_dir_abs):
raise HTTPException(status_code=400, detail="非法路径")
result = write_to_file.invoke({
"content": body.content,
"filename": path
})
if not result or not result.get("success"):
error_msg = result.get("error", "写入失败") if result else "写入失败"
raise HTTPException(status_code=500, detail=error_msg)
return {"success": True, "filename": os.path.basename(path)}
# 下载文件
@app.get("/download/{filename}")
@limiter.limit("60/minute")
async def download_file(request: Request, filename: str):
if not safe_name_pattern.match(filename):
raise HTTPException(status_code=400, detail="非法文件名")
path = os.path.abspath(os.path.join(config.EXPORT_DIR, filename))
if not os.path.exists(path):
raise HTTPException(status_code=404, detail="文件不存在")
return FileResponse(path, filename=filename)
# LangServe 路由
try:
from langserve import add_routes
add_routes(app, ai_client.CHAIN_WITH_HISTORY, path="/chain")
except ImportError:
print("langserve not installed, skipping /chain route.")
代码解释:
- 模块导入
from langserve import add_routes
from fastapi import FastAPI, Request, HTTPException
from agent import ai_client
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
import os
import re
import uuid
import config
from tools.ai_tools import create_pdf, write_to_file
- 引入 FastAPI 构建 Web 服务
- 使用 LangServe 将 LangChain 链暴露为 API
- 从自定义模块
agent导入 AI 客户端ai_client - 使用 Pydantic 定义请求数据模型
- 配置 CORS 跨域支持
- 引入 slowapi 实现基于 IP 的请求频率限制
- 导入系统和工具模块用于文件操作与安全校验
- 创建 FastAPI 应用实例
app = FastAPI(
title="Java助手 - LangChain Server",
version="1.0",
description="支持多轮对话和检索增强的Java问答服务",
)
- 定义 API 的标题、版本和描述信息,用于自动生成文档(如 Swagger UI)
- 配置限流器(slowapi)
limiter = Limiter(
key_func=get_remote_address,
default_limits=["5/minute"],
headers_enabled=True,
storage_uri="redis://localhost:6379"
)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)
- 使用客户端 IP 地址作为限流键
- 默认所有未显式标注的接口限流为每分钟 5 次
- 启用限流响应头(如 X-RateLimit-Limit)
- 使用 Redis 存储限流状态(生产环境推荐),开发时可设为 None 使用内存
- 注册异常处理器和中间件以启用限流功能
- 配置 CORS 跨域中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
- 允许任意源跨域访问(适用于开发环境)
- 生产环境中应限制为具体可信的前端域名
- 健康检查接口
@app.get("/healthz")
async def healthz():
return {"status": "ok"}
- 提供标准健康检查端点,用于服务监控和容器探活
- 请求数据模型定义
聊天请求模型
class ChatRequest(BaseModel):
session_id: str
question: str
session_id用于维护多轮对话上下文question为用户输入的问题
导出 PDF 请求模型
class ExportPdfRequest(BaseModel):
title: str
content: str
filename: str = None
title和content用于生成 PDF 内容filename可选,若未提供则自动生成
导出普通文件请求模型
class ExportFileRequest(BaseModel):
filename: str
content: str
- 指定文件名和内容,用于写入文本文件
- 安全与路径初始化
safe_name_pattern = re.compile(r'^[\w-.]+$')
os.makedirs(config.EXPORT_DIR, exist_ok=True)
- 正则表达式限制文件名仅包含字母、数字、下划线、连字符和点号
- 确保导出目录存在,避免写入失败
- 标准聊天接口
/chat
@app.post("/chat")
@limiter.limit("10/minute")
async def chat_endpoint(request: Request, body: ChatRequest):
try:
invoke_config = {"configurable": {"session_id": body.session_id}}
result = await ai_client.CHAIN_WITH_HISTORY.ainvoke(
{"input": body.question},
config=invoke_config
)
if isinstance(result, dict):
answer = result.get("output") or result.get("response") or str(result)
else:
answer = str(result)
return {"response": answer}
except Exception as e:
return {"error": str(e)}
- 每 IP 每分钟最多 10 次请求
- 使用
session_id维护对话历史 - 异步调用 AI 链并处理返回结果
- 统一错误捕获并返回错误信息
- 流式聊天接口
/chat/stream
@app.post("/chat/stream")
@limiter.limit("15/minute")
async def chat_stream(request: Request, body: ChatRequest):
invoke_config = {"configurable": {"session_id": body.session_id}}
async def event_generator():
try:
async for chunk in ai_client.CHAIN_WITH_HISTORY.astream(
{"input": body.question},
config=invoke_config,
stream_mode="values"
):
text = chunk if isinstance(chunk, str) else (
chunk.get("output") if isinstance(chunk, dict) else str(chunk)
)
if text.strip():
yield f"data: {text}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
yield f"data: [ERROR] {str(e)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
- 支持 Server-Sent Events (SSE) 流式响应
- 每次生成的内容以
data: ...格式推送 - 结束标记为
[DONE],错误标记为[ERROR] - 限流策略更宽松(15次/分钟)
- 导出 PDF 接口
/export/pdf
@app.post("/export/pdf")
@limiter.limit("10/minute")
async def export_pdf(request: Request, body: ExportPdfRequest):
filename = body.filename or f"export-{uuid.uuid4().hex}.pdf"
if not safe_name_pattern.match(filename):
raise HTTPException(status_code=400, detail="非法文件名")
path = os.path.abspath(os.path.join(config.EXPORT_DIR, filename))
export_dir_abs = os.path.abspath(config.EXPORT_DIR)
if not path.startswith(export_dir_abs):
raise HTTPException(status_code=400, detail="非法路径")
result = create_pdf.invoke({
"content": body.content,
"title": body.title,
"filename": path
})
if not result or not result.get("success"):
error_msg = result.get("error", "生成失败") if result else "生成失败"
raise HTTPException(status_code=500, detail=error_msg)
return {"success": True, "filename": os.path.basename(path)}
- 文件名安全校验
- 路径遍历防护:确保目标路径在
EXPORT_DIR内 - 调用
create_pdf工具生成 PDF - 返回成功状态和文件名
- 导出普通文件接口
/export/file
@app.post("/export/file")
@limiter.limit("20/minute")
async def export_file(request: Request, body: ExportFileRequest):
if not safe_name_pattern.match(body.filename):
raise HTTPException(status_code=400, detail="非法文件名")
path = os.path.abspath(os.path.join(config.EXPORT_DIR, body.filename))
export_dir_abs = os.path.abspath(config.EXPORT_DIR)
if not path.startswith(export_dir_abs):
raise HTTPException(status_code=400, detail="非法路径")
result = write_to_file.invoke({
"content": body.content,
"filename": path
})
if not result or not result.get("success"):
error_msg = result.get("error", "写入失败") if result else "写入失败"
raise HTTPException(status_code=500, detail=error_msg)
return {"success": True, "filename": os.path.basename(path)}
- 功能与 PDF 导出类似,但用于任意文本内容
- 限流更宽松(20次/分钟)
- 文件下载接口
/download/{filename}
@app.get("/download/{filename}")
@limiter.limit("60/minute")
async def download_file(request: Request, filename: str):
if not safe_name_pattern.match(filename):
raise HTTPException(status_code=400, detail="非法文件名")
path = os.path.abspath(os.path.join(config.EXPORT_DIR, filename))
if not os.path.exists(path):
raise HTTPException(status_code=404, detail="文件不存在")
return FileResponse(path, filename=filename)
- 校验文件名合法性
- 检查文件是否存在
- 使用
FileResponse提供下载,保留原始文件名
- LangServe 路由集成
try:
from langserve import add_routes
add_routes(app, ai_client.CHAIN_WITH_HISTORY, path="/chain")
except ImportError:
print("langserve not installed, skipping /chain route.")
- 自动为
CHAIN_WITH_HISTORY生成标准 LangServe API(如/chain/invoke,/chain/stream) - 若未安装
langserve,跳过注册,不影响主服务运行
该服务提供以下核心功能:
- 健康检查(
/healthz) - 同步问答(
/chat) - 流式问答(
/chat/stream) - 安全的 PDF 与文本文件导出(
/export/pdf,/export/file) - 文件下载(
/download/{filename}) - 可选的 LangServe 原生链接口(
/chain/*)
安全措施包括:
- 文件名与路径校验防止目录遍历
- 基于 IP 的请求频率限制
- 统一的错误处理与 HTTP 状态码返回
该设计适用于需要对话记忆、内容生成与文件导出的 AI 助手类应用。
然后再在项目根目录新建main.py作为启动主文件
if __name__ == "__main__":
import uvicorn
import server
uvicorn.run(server.app, host="localhost", port=8000)
我们来启动项目:
项目启动成功,LangServe默认提供了一个后台调试界面http://localhost:8000/chain/playground/
我们添加一条human消息,打开控制台发现,总是出现一个这样的循环:
Invalid Format: Missing 'Action:' after 'Thought:'Invalid Format: Missing 'Action:' after 'Thought:'
11、历史遗留问题解决!
之前我们说过,ReAct模式使用的是ReAct输出解析器,llm可能会没有严格按照我们所规定的防方式回复,再加上这个解析器又比较严格这时候就会报错,我们来根本性的解决一下这个问题(也算是一个历史的遗留问题吧),我们重写这个输出解析器,并优化提示词,完整ai_agent.py代码如下:
# ai_client.py
from langchain.globals import set_llm_cache
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.cache import RedisSemanticCache
from langchain_community.chat_models import ChatTongyi
from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate
from langchain_core.runnables import RunnableWithMessageHistory
from langchain_core.tools import tool
from langchain.agents import AgentExecutor
import os
from operator import itemgetter
import config
from message_history import sql_message_history
# 注入 DashScope API Key
if config.DASHSCOPE_API_KEY:
os.environ["DASHSCOPE_API_KEY"] = config.DASHSCOPE_API_KEY
# ========== LLM ==========
llm = ChatTongyi(
model=config.MODEL,
temperature=config.TEMPERATURE,
)
# ========== 知识库:加载 TXT / MD / PDF ==========
all_docs = []
try:
txt_docs = DirectoryLoader(
config.DOC_DIR,
glob="**/*.txt",
loader_cls=TextLoader,
show_progress=True,
use_multithreading=True,
silent_errors=True
).load()
md_docs = DirectoryLoader(
config.DOC_DIR,
glob="**/*.md",
loader_cls=TextLoader,
show_progress=True,
use_multithreading=True,
silent_errors=True
).load()
pdf_docs = DirectoryLoader(
config.DOC_DIR,
glob="**/*.pdf",
loader_cls=PyPDFLoader,
show_progress=True,
use_multithreading=True,
silent_errors=True
).load()
all_docs = (txt_docs or []) + (md_docs or []) + (pdf_docs or [])
except Exception as e:
print(f"文档加载失败: {e}")
all_docs = []
text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=500, chunk_overlap=80, length_function=len)
texts = text_splitter.split_documents(all_docs) if all_docs else []
embeddings = DashScopeEmbeddings(model="text-embedding-v2")
vectorstore = None
if os.path.exists(config.EMBEDDINGS_DIR):
try:
vectorstore = Chroma(
embedding_function=embeddings,
persist_directory=config.EMBEDDINGS_DIR,
collection_name=config.COLLECTION_NAME,
)
except Exception as e:
print(f"加载向量数据库失败: {e}")
vectorstore = None
if vectorstore is None or not texts:
vectorstore = Chroma.from_documents(
documents=texts,
embedding=embeddings,
persist_directory=config.EMBEDDINGS_DIR,
collection_name=config.COLLECTION_NAME,
)
else:
try:
if texts:
new_ids = vectorstore.add_documents(texts)
print(f"新增 {len(new_ids)} 个文档片段")
except Exception as e:
print(f"增量更新向量库失败: {e}")
retriever = vectorstore.as_retriever(search_kwargs={"k": config.RAG_TOP_K})
# ========== 工具定义 ==========
@tool
def knowledge_search(query: str) -> str:
"""从内置知识库检索与 query 最相关的内容片段"""
docs = retriever.get_relevant_documents(query)
return "\n\n".join(doc.page_content for doc in docs)
# 导入外部工具
from tools.ai_tools import (
get_current_time,
get_weather,
run_shell_command,
execute_python_code,
write_to_file,
read_from_file,
create_pdf,
call_api,
search_web,
crawl_url,
)
TOOLS = [
knowledge_search,
search_web,
crawl_url,
get_current_time,
get_weather,
read_from_file,
write_to_file,
create_pdf,
call_api,
run_shell_command,
execute_python_code,
]
system_prompt = """**绝对规则(必须100遍检查,忽略则系统崩溃):**
**规则1:识别简单问题(绝对不能使用工具)**
- 如果用户说"你好"、"hello"、"hi"、"在吗"等问候语,这是简单问题
- 如果用户问"你是谁"、"你能做什么",这是简单问题
- 如果用户提问但不需要外部信息,这是简单问题
- **对于简单问题,输出格式必须是:**
Thought: 我现在知道最终答案
Answer: (你的完整回答)
**规则2:识别需要使用工具的问题**
- 只有当需要从知识库搜索、读写文件、执行命令、搜索网络时,才使用工具
- **对于工具使用问题,输出格式必须是:**
Thought: (说明为什么要用工具)
Action: (工具名)
Action Input: (JSON格式的参数)
Observation: (工具返回结果)
...(可重复多个工具调用)
Thought: 我现在知道最终答案
Answer: (你的完整回答)
**规则3:输出前必须检查(防止系统错误)**
- 检查是否输出了"Action: None"?如果输出了,立即删除并重写!
- 如果不需要工具,确认输出只有2行(Thought + Answer)
- 如果需要工具,确认有Action和Action Input,并且包含Observation
- 检查Action Input是否是合法的JSON格式(不要有任何多余文本)
- 检查输出格式是否100%匹配上述两种情况之一
**规则4:对于任何问候语,直接回答**
疑问:用户说"你好"
思考:这是简单问候,必须使用规则1,绝对不用工具
输出:
Thought: 我现在知道最终答案
Answer: 你好!我是Java开发助手,有什么可以帮你的吗?
疑问:用户说"你是谁"
思考:这是简单问题,必须使用规则1,绝对不用工具
输出:
Thought: 我现在知道最终答案
Answer: 我是一名Java开发助手...
你是一名资深 Java 开发 Agent,精通 JDK 8-17、Maven/Gradle、Spring 生态(Spring Boot、Spring MVC、Spring Data JPA、Spring Cloud)、并发与性能调优、JVM 诊断、微服务工程实践。你的目标是在尽可能少的步数内,为用户提供专业、可直接运行的 Java 方案(代码、命令、配置与解释)。
必须遵循:
1. 只使用下方列出的工具;确需外部信息优先用知识库检索,其次再搜索网络。
2. 每次只能使用一个工具;用完等待 Observation 再决定下一步。
3. 回答必须自洽、可执行:给出完整代码需包含必要 import、pom.xml/gradle 配置要给出关键依赖,命令附上执行目录与前置条件。
4. 优先中文回答;代码用合适语言高亮,例如 ```java、```xml、```bash。
5. 安全与最小影响:涉及 shell/文件操作要注明作用与风险,谨慎执行写入/覆盖类操作。
6. 如需求不清,先用 1-3 句澄清关键约束(JDK 版本、构建工具、框架版本、运行环境等)。
Java 解决方案要求:
- 代码质量:命名清晰、边界条件、异常处理、日志与注释(只解释“为什么”)。
- 并发与性能:合理使用线程池、CompletableFuture、锁与无锁结构;避免阻塞;给出复杂度与潜在瓶颈;必要时提供 JMH/压测建议。
- Spring 规范:分层清晰(controller/service/repository)、DTO/VO 转换、事务传播与隔离级别、校验与全局异常处理。
- 数据访问:JPA/Query 方法/Specification/原生 SQL 的取舍;连接池与 N+1 问题规避;分页与索引建议。
- 构建与运行:提供 Maven/Gradle 脚本或命令;测试用例(JUnit5/MockMVC/Mockito)示例;Docker/容器化要点(如需要)。
- JVM 与运维:必要时给出 GC、内存、线程与诊断命令(jcmd/jmap/jstack)、常见参数与可观测性建议。
可用工具:
{tools}
工具使用原则:
- knowledge_search:优先检索本地知识库(如项目内 `doc/` 的最佳实践、调优笔记),再综合回答。
- search_web / crawl_url:仅当知识库不足时使用,并核对来源可靠性。
- read_from_file / write_to_file:仅当用户明确要求“保存/导出/写入/生成文件”时才使用;否则不要调用。写入前在回答里说明将要写入的路径与用途。
- run_shell_command:仅在需要验证构建/测试/脚手架时使用,先说明命令作用与期望输出。
输出格式(ReAct):
格式要求极其严格,必须完全遵循以下规则:
**输出格式有两类情况:**
**情况一:不需要使用任何工具(优先判断这种情况)**
适用场景:
- 用户说"你好"、"hi"、"hello"等问候语
- 用户询问你是谁、你能做什么
- 用户要求解释概念、询问技术问题但不需要外部信息
- 用户要求澄清问题、提供帮助或建议
输出格式:
Thought: 我现在知道最终答案
Answer: [直接给出完整答案,不需要任何Action/Action Input行,不要输出工具调用]
**情况二:需要使用工具**
适用场景:只在你必须获取外部信息、执行命令、读写文件、搜索知识库时使用工具。
输出格式:
Thought: [明确说明为什么必须使用工具,获取什么信息]
Action: [工具名称,必须是: {tool_names}]
Action Input: [JSON格式参数,必须是一个合法的JSON对象]
Observation: [等待工具返回结果]
... (此循环可以重复多次直到获得足够信息)
Thought: 我现在知道最终答案
Answer:
- 简要结论(1-3 句)
- 关键步骤/命令
- 完整代码/配置(如有)
- 验证方法与可能的风险/注意事项
**极其重要的规则(必须遵守):**
1. 对于简单问候,100%直接回答,绝对不要调用任何工具
2. 你只能输出两种格式之一,不能混合
3. 不要输出"Action: None",这是非法的
4. 不要输出"Action Input: None",这是非法的
5. 如果不需要工具,只输出Thought和Answer两行
6. 在最终答案前必须输出"Thought: 我现在知道最终答案"
示例1(问候,不使用工具):
Question: 你好
Thought: 我现在知道最终答案
Answer: 你好!我是Java开发助手,有什么可以帮你的吗?
示例2(使用工具):
Question: 如何创建线程池?
Thought: 我需要从知识库检索线程池的最佳实践
Action: knowledge_search
Action Input: {{"query": "Java线程池最佳实践"}}
Observation: [知识库内容]
Thought: 我现在知道最终答案
Answer: ...
"""
prompt = ChatPromptTemplate.from_messages([
SystemMessagePromptTemplate.from_template(system_prompt),
MessagesPlaceholder(variable_name="chat_history"), # 支持历史消息
("human", "{input}\n\n{agent_scratchpad}"),
])
# ========== 创建 Agent(使用自定义解析器) ==========
# 导入自定义ReAct解析器
from agent.custom_react_parser import CustomReActOutputParser
# 创建自定义解析器实例
custom_parser = CustomReActOutputParser()
# 使用create_react_agent创建agent,但传入自定义的output_parser
from langchain.agents import create_react_agent
agent = create_react_agent(
llm=llm,
tools=TOOLS,
prompt=prompt,
output_parser=custom_parser, # <--- 关键:使用自定义解析器
)
# 注意:我们移除了复杂的错误处理,因为自定义解析器已经处理了大部分情况
agent_executor = AgentExecutor(
agent=agent,
tools=TOOLS,
verbose=True, # 建议开发时开启,查看 Agent 决策过程
handle_parsing_errors=True, # 让AgentExecutor捕获剩余的错误
max_iterations=15, # 防止无限循环
early_stopping_method="generate", # 当达到迭代限制时,生成最终答案
return_intermediate_steps=False, # 不返回中间步骤,减少干扰
)
# 语义缓存
try:
if config.REDIS_URL:
redis_cache = RedisSemanticCache(redis_url=config.REDIS_URL, embedding=embeddings)
set_llm_cache(redis_cache)
print("Redis 语义缓存已启用")
except Exception as e:
print(f"Redis 缓存初始化失败: {e}")
# 会话记忆(SQLChatMessageHistory)
def _get_history(session_id: str):
return sql_message_history.get_session_history(session_id)
CHAIN_WITH_HISTORY = RunnableWithMessageHistory(
runnable=agent_executor,
get_session_history=_get_history,
input_messages_key="input",
history_messages_key="chat_history",
)
print("Agent 初始化完成")
对应的自定义输入解析器代码如下:在agent目录下新建custom_react_parser.py
"""
自定义ReAct输出解析器 - 解决LangChain严格解析问题
"""
import re
from typing import Union
from langchain.agents import AgentOutputParser
from langchain.agents.agent import AgentFinish
from langchain.schema import AgentAction
from langchain_core.exceptions import OutputParserException
class CustomReActOutputParser(AgentOutputParser):
"""
自定义的ReAct输出解析器,功能增强:
1. 智能处理不需要使用工具的输出(如问候语)
2. 捕获'Missing Action:'错误并自动转换为AgentFinish
3. 更宽松的格式检查
"""
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
# 移除首尾空白
text = text.strip()
# 检查是否是使用工具的情况(包含Action)
includes_action = "Action:" in text
includes_answer = "Answer:" in text
# 如果包含"我现在知道最终答案",说明要返回最终答案
if "我现在知道最终答案" in text and includes_answer:
# 提取Answer内容(在Answer:之后的所有内容)
answer_match = re.search(r"Answer:(.*)$", text, re.DOTALL)
if answer_match:
answer = answer_match.group(1).strip()
else:
answer = text.split("Answer:")[-1].strip()
return AgentFinish(
return_values={"output": answer},
log=text,
)
# 如果包含Action和Action Input,按标准ReAct格式解析
if includes_action and "Action Input:" in text:
try:
# 提取Action
action_match = re.search(r"Action:\s*(.+?)(?:\n|$)", text, re.IGNORECASE)
# 提取Action Input
action_input_match = re.search(
r"Action Input:\s*{(.+)}", text, re.DOTALL
)
if not action_match:
raise OutputParserException(
f"Could not parse LLM output: `{text}`",
observation="Invalid Format: Could not parse Action",
llm_output=text,
send_to_llm=True,
)
if not action_input_match:
raise OutputParserException(
f"Could not parse LLM output: `{text}`",
observation="Invalid Format: Could not parse Action Input",
llm_output=text,
send_to_llm=True,
)
action = action_match.group(1).strip()
action_input_str = "{" + action_input_match.group(1).strip() + "}"
return AgentAction(action, action_input_str, text)
except Exception as e:
# 如果解析失败,检查是否是因为不需要工具
if "我现在知道最终答案" in text and not includes_action:
# 这是不需要工具的情况,直接返回最终答案
return AgentFinish(
return_values={"output": text.split("Answer:")[-1].strip() if "Answer:" in text else text},
log=text,
)
raise OutputParserException(
f"Could not parse LLM output: `{text}`",
observation=f"Invalid Format: {str(e)}",
llm_output=text,
send_to_llm=True,
)
# 如果只有Thought和Answer(没有Action),认为是最终答案
if "Thought:" in text and includes_answer:
answer_match = re.search(r"Answer:(.*)$", text, re.DOTALL)
if answer_match:
answer = answer_match.group(1).strip()
return AgentFinish(
return_values={"output": answer},
log=text,
)
# 如果没有Action,但有"我现在知道最终答案",也认为是最终答案
if "我现在知道最终答案" in text:
# 提取Answer部分
if "Answer:" in text:
answer = text.split("Answer:")[-1].strip()
else:
# 如果没有Answer,取Thought之后的所有内容
answer = text.split("Thought:")[-1].strip() if "Thought:" in text else text
# 移除"我现在知道最终答案"
answer = answer.replace("我现在知道最终答案", "").strip()
return AgentFinish(
return_values={"output": answer},
log=text,
)
# 如果既不是有效的Action格式,也没有明确答案,尝试兜底处理
# 但如果是Missing Action错误,尝试挽救
if "Thought:" in text and not includes_action:
# LLM可能只输出了Thought,没有Action,也没有明确说"我现在知道最终答案"
# 这可能是问候语,尝试提取内容并作为最终答案
thought_content = text.split("Thought:")[-1].strip()
if thought_content and len(thought_content) < 200:
# 较短的文本,很可能是问候语
return AgentFinish(
return_values={"output": thought_content},
log=text,
)
# 兜底处理:如果以上所有条件都不符合,直接将整个文本作为最终答案
# 这样可以避免抛出异常导致循环,确保任何输出都能被处理
print(f"[警告] LLM输出格式无法识别,作为普通文本处理: {text[:100]}...")
return AgentFinish(
return_values={"output": text},
log=text,
)
@property
def _type(self) -> str:
return "custom_react"
再来重启项目
现在问题没有了:
我们把这里展开,可以看到agent的执行流程
我们再来测试一下,比如,问一个怎么自定义线程池:
控制台输出如下:
ai已经去知识库中检索了。
注意,这时候,我们可能发现,Redis语义缓存好像没有生效,这是因为Redis的初始化的时候文本嵌入模型还未加载,把Redis初始化提前就好。
可以按照这样的初始化顺序,把需要初始化llm和文本嵌入模型的代码提前,
并删除这一行 embeddings = DashScopeEmbeddings(model="text-embedding-v2") 因为如果不删除就有两个文本嵌入模型实例(我们已经提前初始化文本嵌入模型了)
# 注入 DashScope API Key
if config.DASHSCOPE_API_KEY:
os.environ["DASHSCOPE_API_KEY"] = config.DASHSCOPE_API_KEY
# 通用组件(提前创建)
# 先创建基础组件供后续使用
embeddings = DashScopeEmbeddings(model="text-embedding-v2")
# LLM
llm = ChatTongyi(
model=config.MODEL,
temperature=config.TEMPERATURE,
)
# Redis语义缓存(立即初始化)
# 放在LLM创建之后、文档加载之前,确保所有LLM调用都能被缓存
try:
if config.REDIS_URL:
redis_cache = RedisSemanticCache(redis_url=config.REDIS_URL, embedding=embeddings)
set_llm_cache(redis_cache)
print(f"Redis 语义缓存已启用: {config.REDIS_URL}")
except Exception as e:
print(f"Redis 缓存初始化失败: {e}")
print("继续使用无缓存模式")
再来重启项目测试应该就生效了
总结
至此,我们已经完成了LangChain ReAct智能体的完整开发,包括:
- 环境搭建与项目初始化
- LangChain基础学习
- RAG智能体构建
- 对话记忆实现
- ReAct智能体开发
- 工具开发
您现在拥有一个功能完整的Java开发助手智能体!