Java程序员必做项目!基于LangChain实现的ReAct智能体项目。助你拿offer!(完结)

12 阅读16分钟

10.使用LangServe+FastAPI服务化!

本项目中我们使用的是LangServe+fastAPI实现的服务化,首先要自行安装依赖。

然后在项目根目录之下,新建server.py

编写代码如下:

from langserve import add_routes
from fastapi import FastAPI, Request, HTTPException
from agent import ai_client
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
​
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
​
import os
import re
import uuid
import config
from tools.ai_tools import create_pdf, write_to_file
​
# FastAPI接口定义
app = FastAPI(
    title="Java助手 - LangChain Server",
    version="1.0",
    description="支持多轮对话和检索增强的Java问答服务",
)
​
# 使用slowapi进行基于客户端ip的限流配置来创建一个限流器对象
limiter = Limiter(
    key_func=get_remote_address,  # 使用客户端 IP 作为限流 key
    default_limits=["5/minute"],  # 默认限流:每分钟最多 5 次
    headers_enabled=True,  # 返回 X-RateLimit-* 头
    storage_uri="redis://localhost:6379"  # 使用内存存储(生产建议用 redis://localhost:6379)
)
app.state.limiter = limiter
# FastAPI添加异常处理器
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)  # 使用 slowapi 内置 handler
# FastAPI添加中间件(限流)
app.add_middleware(SlowAPIMiddleware)
​
# 跨域配置
app.add_middleware(
    # 内置跨域中间件
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
​
​
# 健康检查接口
@app.get("/healthz")
async def healthz():
    return {"status": "ok"}
​
​
# 请求模型定义
class ChatRequest(BaseModel):
    session_id: str
    question: str# 导出pdf请求模型定义
class ExportPdfRequest(BaseModel):
    title: str
    content: str
    filename: str = None# 导出文件请求模型定义
class ExportFileRequest(BaseModel):
    filename: str
    content: str
​
​
# 安全检查
safe_name_pattern = re.compile(r'^[\w-.]+$')
os.makedirs(config.EXPORT_DIR, exist_ok=True)
​
​
# 聊天接口(标准响应)
@app.post("/chat")
# 限流每分钟10次
@limiter.limit("10/minute")
# 异步支持,(返回的是一个python协程对象)
async def chat_endpoint(request: Request, body: ChatRequest):
    try:
        invoke_config = {"configurable": {"session_id": body.session_id}}
        result = await ai_client.CHAIN_WITH_HISTORY.ainvoke(
            {"input": body.question},
            config=invoke_config
        )
        # 如果ai返回是一个字典的话
        if isinstance(result, dict):
            answer = result.get("output") or result.get("response") or str(result)
            # 不是字典的话直接转化成字符串
        else:
            answer = str(result)
        return {"response": answer}
    except Exception as e:
        return {"error": str(e)}
​
​
# 流式聊天接口
@app.post("/chat/stream")
# 每分钟15次
@limiter.limit("15/minute")
async def chat_stream(request: Request, body: ChatRequest):
    invoke_config = {"configurable": {"session_id": body.session_id}}
​
    async def event_generator():
        try:
            async for chunk in ai_client.CHAIN_WITH_HISTORY.astream(
                {"input": body.question},
                config=invoke_config,
                stream_mode="values"
            ):
                # 根据 chain 返回结构处理 chunk
                text = chunk if isinstance(chunk, str) else (
                    chunk.get("output") if isinstance(chunk, dict) else str(chunk)
                )
                if text.strip():
                    yield f"data: {text}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            yield f"data: [ERROR] {str(e)}\n\n"
​
    return StreamingResponse(event_generator(), media_type="text/event-stream")
​
​
# 导出 PDF
@app.post("/export/pdf")
@limiter.limit("10/minute")
async def export_pdf(request: Request, body: ExportPdfRequest):
    filename = body.filename or f"export-{uuid.uuid4().hex}.pdf"
​
    if not safe_name_pattern.match(filename):
        raise HTTPException(status_code=400, detail="非法文件名")
​
    path = os.path.abspath(os.path.join(config.EXPORT_DIR, filename))
    export_dir_abs = os.path.abspath(config.EXPORT_DIR)
​
    if not path.startswith(export_dir_abs):
        raise HTTPException(status_code=400, detail="非法路径")
​
    result = create_pdf.invoke({
        "content": body.content,
        "title": body.title,
        "filename": path
    })
​
    if not result or not result.get("success"):
        error_msg = result.get("error", "生成失败") if result else "生成失败"
        raise HTTPException(status_code=500, detail=error_msg)
​
    return {"success": True, "filename": os.path.basename(path)}
​
​
# 导出普通文件
@app.post("/export/file")
@limiter.limit("20/minute")
async def export_file(request: Request, body: ExportFileRequest):
    if not safe_name_pattern.match(body.filename):
        raise HTTPException(status_code=400, detail="非法文件名")
​
    path = os.path.abspath(os.path.join(config.EXPORT_DIR, body.filename))
    export_dir_abs = os.path.abspath(config.EXPORT_DIR)
​
    if not path.startswith(export_dir_abs):
        raise HTTPException(status_code=400, detail="非法路径")
​
    result = write_to_file.invoke({
        "content": body.content,
        "filename": path
    })
​
    if not result or not result.get("success"):
        error_msg = result.get("error", "写入失败") if result else "写入失败"
        raise HTTPException(status_code=500, detail=error_msg)
​
    return {"success": True, "filename": os.path.basename(path)}
​
​
# 下载文件
@app.get("/download/{filename}")
@limiter.limit("60/minute")
async def download_file(request: Request, filename: str):
    if not safe_name_pattern.match(filename):
        raise HTTPException(status_code=400, detail="非法文件名")
​
    path = os.path.abspath(os.path.join(config.EXPORT_DIR, filename))
    if not os.path.exists(path):
        raise HTTPException(status_code=404, detail="文件不存在")
​
    return FileResponse(path, filename=filename)
​
​
# LangServe 路由
try:
    from langserve import add_routes
​
    add_routes(app, ai_client.CHAIN_WITH_HISTORY, path="/chain")
except ImportError:
    print("langserve not installed, skipping /chain route.")

代码解释:

  1. 模块导入
from langserve import add_routes
from fastapi import FastAPI, Request, HTTPException
from agent import ai_client
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse

from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware

import os
import re
import uuid
import config
from tools.ai_tools import create_pdf, write_to_file
  • 引入 FastAPI 构建 Web 服务
  • 使用 LangServe 将 LangChain 链暴露为 API
  • 从自定义模块 agent 导入 AI 客户端 ai_client
  • 使用 Pydantic 定义请求数据模型
  • 配置 CORS 跨域支持
  • 引入 slowapi 实现基于 IP 的请求频率限制
  • 导入系统和工具模块用于文件操作与安全校验
  1. 创建 FastAPI 应用实例
app = FastAPI(
    title="Java助手 - LangChain Server",
    version="1.0",
    description="支持多轮对话和检索增强的Java问答服务",
)
  • 定义 API 的标题、版本和描述信息,用于自动生成文档(如 Swagger UI)
  1. 配置限流器(slowapi)
limiter = Limiter(
    key_func=get_remote_address,
    default_limits=["5/minute"],
    headers_enabled=True,
    storage_uri="redis://localhost:6379"
)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)
  • 使用客户端 IP 地址作为限流键
  • 默认所有未显式标注的接口限流为每分钟 5 次
  • 启用限流响应头(如 X-RateLimit-Limit)
  • 使用 Redis 存储限流状态(生产环境推荐),开发时可设为 None 使用内存
  • 注册异常处理器和中间件以启用限流功能
  1. 配置 CORS 跨域中间件
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  • 允许任意源跨域访问(适用于开发环境)
  • 生产环境中应限制为具体可信的前端域名
  1. 健康检查接口
@app.get("/healthz")
async def healthz():
    return {"status": "ok"}
  • 提供标准健康检查端点,用于服务监控和容器探活
  1. 请求数据模型定义

聊天请求模型

class ChatRequest(BaseModel):
    session_id: str
    question: str
  • session_id 用于维护多轮对话上下文
  • question 为用户输入的问题

导出 PDF 请求模型

class ExportPdfRequest(BaseModel):
    title: str
    content: str
    filename: str = None
  • titlecontent 用于生成 PDF 内容
  • filename 可选,若未提供则自动生成

导出普通文件请求模型

class ExportFileRequest(BaseModel):
    filename: str
    content: str
  • 指定文件名和内容,用于写入文本文件
  1. 安全与路径初始化
safe_name_pattern = re.compile(r'^[\w-.]+$')
os.makedirs(config.EXPORT_DIR, exist_ok=True)
  • 正则表达式限制文件名仅包含字母、数字、下划线、连字符和点号
  • 确保导出目录存在,避免写入失败
  1. 标准聊天接口 /chat
@app.post("/chat")
@limiter.limit("10/minute")
async def chat_endpoint(request: Request, body: ChatRequest):
    try:
        invoke_config = {"configurable": {"session_id": body.session_id}}
        result = await ai_client.CHAIN_WITH_HISTORY.ainvoke(
            {"input": body.question},
            config=invoke_config
        )
        if isinstance(result, dict):
            answer = result.get("output") or result.get("response") or str(result)
        else:
            answer = str(result)
        return {"response": answer}
    except Exception as e:
        return {"error": str(e)}
  • 每 IP 每分钟最多 10 次请求
  • 使用 session_id 维护对话历史
  • 异步调用 AI 链并处理返回结果
  • 统一错误捕获并返回错误信息
  1. 流式聊天接口 /chat/stream
@app.post("/chat/stream")
@limiter.limit("15/minute")
async def chat_stream(request: Request, body: ChatRequest):
    invoke_config = {"configurable": {"session_id": body.session_id}}

    async def event_generator():
        try:
            async for chunk in ai_client.CHAIN_WITH_HISTORY.astream(
                {"input": body.question},
                config=invoke_config,
                stream_mode="values"
            ):
                text = chunk if isinstance(chunk, str) else (
                    chunk.get("output") if isinstance(chunk, dict) else str(chunk)
                )
                if text.strip():
                    yield f"data: {text}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            yield f"data: [ERROR] {str(e)}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
  • 支持 Server-Sent Events (SSE) 流式响应
  • 每次生成的内容以 data: ... 格式推送
  • 结束标记为 [DONE],错误标记为 [ERROR]
  • 限流策略更宽松(15次/分钟)
  1. 导出 PDF 接口 /export/pdf
@app.post("/export/pdf")
@limiter.limit("10/minute")
async def export_pdf(request: Request, body: ExportPdfRequest):
    filename = body.filename or f"export-{uuid.uuid4().hex}.pdf"

    if not safe_name_pattern.match(filename):
        raise HTTPException(status_code=400, detail="非法文件名")

    path = os.path.abspath(os.path.join(config.EXPORT_DIR, filename))
    export_dir_abs = os.path.abspath(config.EXPORT_DIR)

    if not path.startswith(export_dir_abs):
        raise HTTPException(status_code=400, detail="非法路径")

    result = create_pdf.invoke({
        "content": body.content,
        "title": body.title,
        "filename": path
    })

    if not result or not result.get("success"):
        error_msg = result.get("error", "生成失败") if result else "生成失败"
        raise HTTPException(status_code=500, detail=error_msg)

    return {"success": True, "filename": os.path.basename(path)}
  • 文件名安全校验
  • 路径遍历防护:确保目标路径在 EXPORT_DIR
  • 调用 create_pdf 工具生成 PDF
  • 返回成功状态和文件名
  1. 导出普通文件接口 /export/file
@app.post("/export/file")
@limiter.limit("20/minute")
async def export_file(request: Request, body: ExportFileRequest):
    if not safe_name_pattern.match(body.filename):
        raise HTTPException(status_code=400, detail="非法文件名")

    path = os.path.abspath(os.path.join(config.EXPORT_DIR, body.filename))
    export_dir_abs = os.path.abspath(config.EXPORT_DIR)

    if not path.startswith(export_dir_abs):
        raise HTTPException(status_code=400, detail="非法路径")

    result = write_to_file.invoke({
        "content": body.content,
        "filename": path
    })

    if not result or not result.get("success"):
        error_msg = result.get("error", "写入失败") if result else "写入失败"
        raise HTTPException(status_code=500, detail=error_msg)

    return {"success": True, "filename": os.path.basename(path)}
  • 功能与 PDF 导出类似,但用于任意文本内容
  • 限流更宽松(20次/分钟)
  1. 文件下载接口 /download/{filename}
@app.get("/download/{filename}")
@limiter.limit("60/minute")
async def download_file(request: Request, filename: str):
    if not safe_name_pattern.match(filename):
        raise HTTPException(status_code=400, detail="非法文件名")

    path = os.path.abspath(os.path.join(config.EXPORT_DIR, filename))
    if not os.path.exists(path):
        raise HTTPException(status_code=404, detail="文件不存在")

    return FileResponse(path, filename=filename)
  • 校验文件名合法性
  • 检查文件是否存在
  • 使用 FileResponse 提供下载,保留原始文件名
  1. LangServe 路由集成
try:
    from langserve import add_routes
    add_routes(app, ai_client.CHAIN_WITH_HISTORY, path="/chain")
except ImportError:
    print("langserve not installed, skipping /chain route.")
  • 自动为 CHAIN_WITH_HISTORY 生成标准 LangServe API(如 /chain/invoke, /chain/stream
  • 若未安装 langserve,跳过注册,不影响主服务运行

该服务提供以下核心功能:

  • 健康检查(/healthz
  • 同步问答(/chat
  • 流式问答(/chat/stream
  • 安全的 PDF 与文本文件导出(/export/pdf, /export/file
  • 文件下载(/download/{filename}
  • 可选的 LangServe 原生链接口(/chain/*

安全措施包括:

  • 文件名与路径校验防止目录遍历
  • 基于 IP 的请求频率限制
  • 统一的错误处理与 HTTP 状态码返回

该设计适用于需要对话记忆、内容生成与文件导出的 AI 助手类应用。

然后再在项目根目录新建main.py作为启动主文件


if __name__ == "__main__":
    import uvicorn
    import server
    uvicorn.run(server.app, host="localhost", port=8000)

我们来启动项目:

image-20251209090359323

项目启动成功,LangServe默认提供了一个后台调试界面http://localhost:8000/chain/playground/

image-20251209090958298

我们添加一条human消息,打开控制台发现,总是出现一个这样的循环:

Invalid Format: Missing 'Action:' after 'Thought:'Invalid Format: Missing 'Action:' after 'Thought:'

11、历史遗留问题解决!

之前我们说过,ReAct模式使用的是ReAct输出解析器,llm可能会没有严格按照我们所规定的防方式回复,再加上这个解析器又比较严格这时候就会报错,我们来根本性的解决一下这个问题(也算是一个历史的遗留问题吧),我们重写这个输出解析器,并优化提示词,完整ai_agent.py代码如下:

# ai_client.py
from langchain.globals import set_llm_cache
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.cache import RedisSemanticCache
from langchain_community.chat_models import ChatTongyi
from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate
from langchain_core.runnables import RunnableWithMessageHistory
from langchain_core.tools import tool
from langchain.agents import AgentExecutor
​
import os
from operator import itemgetter
import config
from message_history import sql_message_history
​
# 注入 DashScope API Key
if config.DASHSCOPE_API_KEY:
    os.environ["DASHSCOPE_API_KEY"] = config.DASHSCOPE_API_KEY
​
# ========== LLM ==========
llm = ChatTongyi(
    model=config.MODEL,
    temperature=config.TEMPERATURE,
)
​
# ========== 知识库:加载 TXT / MD / PDF ==========
all_docs = []
try:
    txt_docs = DirectoryLoader(
        config.DOC_DIR,
        glob="**/*.txt",
        loader_cls=TextLoader,
        show_progress=True,
        use_multithreading=True,
        silent_errors=True
    ).load()
    md_docs = DirectoryLoader(
        config.DOC_DIR,
        glob="**/*.md",
        loader_cls=TextLoader,
        show_progress=True,
        use_multithreading=True,
        silent_errors=True
    ).load()
    pdf_docs = DirectoryLoader(
        config.DOC_DIR,
        glob="**/*.pdf",
        loader_cls=PyPDFLoader,
        show_progress=True,
        use_multithreading=True,
        silent_errors=True
    ).load()
    all_docs = (txt_docs or []) + (md_docs or []) + (pdf_docs or [])
except Exception as e:
    print(f"文档加载失败: {e}")
    all_docs = []
​
text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=500, chunk_overlap=80, length_function=len)
texts = text_splitter.split_documents(all_docs) if all_docs else []
​
embeddings = DashScopeEmbeddings(model="text-embedding-v2")
​
vectorstore = None
if os.path.exists(config.EMBEDDINGS_DIR):
    try:
        vectorstore = Chroma(
            embedding_function=embeddings,
            persist_directory=config.EMBEDDINGS_DIR,
            collection_name=config.COLLECTION_NAME,
        )
    except Exception as e:
        print(f"加载向量数据库失败: {e}")
        vectorstore = None
​
if vectorstore is None or not texts:
    vectorstore = Chroma.from_documents(
        documents=texts,
        embedding=embeddings,
        persist_directory=config.EMBEDDINGS_DIR,
        collection_name=config.COLLECTION_NAME,
    )
else:
    try:
        if texts:
            new_ids = vectorstore.add_documents(texts)
            print(f"新增 {len(new_ids)} 个文档片段")
    except Exception as e:
        print(f"增量更新向量库失败: {e}")
​
retriever = vectorstore.as_retriever(search_kwargs={"k": config.RAG_TOP_K})
​
# ========== 工具定义 ==========
​
@tool
def knowledge_search(query: str) -> str:
    """从内置知识库检索与 query 最相关的内容片段"""
    docs = retriever.get_relevant_documents(query)
    return "\n\n".join(doc.page_content for doc in docs)
​
# 导入外部工具
from tools.ai_tools import (
    get_current_time,
    get_weather,
    run_shell_command,
    execute_python_code,
    write_to_file,
    read_from_file,
    create_pdf,
    call_api,
    search_web,
    crawl_url,
)
​
TOOLS = [
    knowledge_search,
    search_web,
    crawl_url,
    get_current_time,
    get_weather,
    read_from_file,
    write_to_file,
    create_pdf,
    call_api,
    run_shell_command,
    execute_python_code,
]
​
system_prompt = """**绝对规则(必须100遍检查,忽略则系统崩溃):**
​
**规则1:识别简单问题(绝对不能使用工具)**
- 如果用户说"你好"、"hello"、"hi"、"在吗"等问候语,这是简单问题
- 如果用户问"你是谁"、"你能做什么",这是简单问题
- 如果用户提问但不需要外部信息,这是简单问题
- **对于简单问题,输出格式必须是:**
Thought: 我现在知道最终答案
Answer: (你的完整回答)
​
**规则2:识别需要使用工具的问题**
- 只有当需要从知识库搜索、读写文件、执行命令、搜索网络时,才使用工具
- **对于工具使用问题,输出格式必须是:**
Thought: (说明为什么要用工具)
Action: (工具名)
Action Input: (JSON格式的参数)
Observation: (工具返回结果)
...(可重复多个工具调用)
Thought: 我现在知道最终答案
Answer: (你的完整回答)
​
**规则3:输出前必须检查(防止系统错误)**
- 检查是否输出了"Action: None"?如果输出了,立即删除并重写!
- 如果不需要工具,确认输出只有2行(Thought + Answer)
- 如果需要工具,确认有Action和Action Input,并且包含Observation
- 检查Action Input是否是合法的JSON格式(不要有任何多余文本)
- 检查输出格式是否100%匹配上述两种情况之一
​
**规则4:对于任何问候语,直接回答**
疑问:用户说"你好"
思考:这是简单问候,必须使用规则1,绝对不用工具
输出:
Thought: 我现在知道最终答案
Answer: 你好!我是Java开发助手,有什么可以帮你的吗?
​
疑问:用户说"你是谁"
思考:这是简单问题,必须使用规则1,绝对不用工具
输出:
Thought: 我现在知道最终答案
Answer: 我是一名Java开发助手...
​
你是一名资深 Java 开发 Agent,精通 JDK 8-17、Maven/Gradle、Spring 生态(Spring Boot、Spring MVC、Spring Data JPA、Spring Cloud)、并发与性能调优、JVM 诊断、微服务工程实践。你的目标是在尽可能少的步数内,为用户提供专业、可直接运行的 Java 方案(代码、命令、配置与解释)。
​
必须遵循:
1. 只使用下方列出的工具;确需外部信息优先用知识库检索,其次再搜索网络。
2. 每次只能使用一个工具;用完等待 Observation 再决定下一步。
3. 回答必须自洽、可执行:给出完整代码需包含必要 import、pom.xml/gradle 配置要给出关键依赖,命令附上执行目录与前置条件。
4. 优先中文回答;代码用合适语言高亮,例如 ```java、```xml、```bash。
5. 安全与最小影响:涉及 shell/文件操作要注明作用与风险,谨慎执行写入/覆盖类操作。
6. 如需求不清,先用 1-3 句澄清关键约束(JDK 版本、构建工具、框架版本、运行环境等)。
​
Java 解决方案要求:
- 代码质量:命名清晰、边界条件、异常处理、日志与注释(只解释“为什么”)。
- 并发与性能:合理使用线程池、CompletableFuture、锁与无锁结构;避免阻塞;给出复杂度与潜在瓶颈;必要时提供 JMH/压测建议。
- Spring 规范:分层清晰(controller/service/repository)、DTO/VO 转换、事务传播与隔离级别、校验与全局异常处理。
- 数据访问:JPA/Query 方法/Specification/原生 SQL 的取舍;连接池与 N+1 问题规避;分页与索引建议。
- 构建与运行:提供 Maven/Gradle 脚本或命令;测试用例(JUnit5/MockMVC/Mockito)示例;Docker/容器化要点(如需要)。
- JVM 与运维:必要时给出 GC、内存、线程与诊断命令(jcmd/jmap/jstack)、常见参数与可观测性建议。
​
可用工具:
{tools}
​
工具使用原则:
- knowledge_search:优先检索本地知识库(如项目内 `doc/` 的最佳实践、调优笔记),再综合回答。
- search_web / crawl_url:仅当知识库不足时使用,并核对来源可靠性。
- read_from_file / write_to_file:仅当用户明确要求“保存/导出/写入/生成文件”时才使用;否则不要调用。写入前在回答里说明将要写入的路径与用途。
- run_shell_command:仅在需要验证构建/测试/脚手架时使用,先说明命令作用与期望输出。
​
输出格式(ReAct):
格式要求极其严格,必须完全遵循以下规则:
​
**输出格式有两类情况:**
​
**情况一:不需要使用任何工具(优先判断这种情况)**
适用场景:
- 用户说"你好"、"hi"、"hello"等问候语
- 用户询问你是谁、你能做什么
- 用户要求解释概念、询问技术问题但不需要外部信息
- 用户要求澄清问题、提供帮助或建议
​
输出格式:
Thought: 我现在知道最终答案
Answer: [直接给出完整答案,不需要任何Action/Action Input行,不要输出工具调用]
​
**情况二:需要使用工具**
适用场景:只在你必须获取外部信息、执行命令、读写文件、搜索知识库时使用工具。
​
输出格式:
Thought: [明确说明为什么必须使用工具,获取什么信息]
Action: [工具名称,必须是: {tool_names}]
Action Input: [JSON格式参数,必须是一个合法的JSON对象]
Observation: [等待工具返回结果]
... (此循环可以重复多次直到获得足够信息)
Thought: 我现在知道最终答案
Answer:
- 简要结论(1-3 句)
- 关键步骤/命令
- 完整代码/配置(如有)
- 验证方法与可能的风险/注意事项
​
**极其重要的规则(必须遵守):**
1. 对于简单问候,100%直接回答,绝对不要调用任何工具
2. 你只能输出两种格式之一,不能混合
3. 不要输出"Action: None",这是非法的
4. 不要输出"Action Input: None",这是非法的
5. 如果不需要工具,只输出Thought和Answer两行
6. 在最终答案前必须输出"Thought: 我现在知道最终答案"
​
示例1(问候,不使用工具):
Question: 你好
Thought: 我现在知道最终答案
Answer: 你好!我是Java开发助手,有什么可以帮你的吗?
​
示例2(使用工具):
Question: 如何创建线程池?
Thought: 我需要从知识库检索线程池的最佳实践
Action: knowledge_search
Action Input: {{"query": "Java线程池最佳实践"}}
Observation: [知识库内容]
Thought: 我现在知道最终答案
Answer: ...
"""prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(system_prompt),
    MessagesPlaceholder(variable_name="chat_history"),  # 支持历史消息
    ("human", "{input}\n\n{agent_scratchpad}"),
])
​
# ========== 创建 Agent(使用自定义解析器) ==========# 导入自定义ReAct解析器
from agent.custom_react_parser import CustomReActOutputParser
​
# 创建自定义解析器实例
custom_parser = CustomReActOutputParser()
​
# 使用create_react_agent创建agent,但传入自定义的output_parser
from langchain.agents import create_react_agent
​
agent = create_react_agent(
    llm=llm,
    tools=TOOLS,
    prompt=prompt,
    output_parser=custom_parser,  # <--- 关键:使用自定义解析器
)
​
# 注意:我们移除了复杂的错误处理,因为自定义解析器已经处理了大部分情况
agent_executor = AgentExecutor(
    agent=agent,
    tools=TOOLS,
    verbose=True,  # 建议开发时开启,查看 Agent 决策过程
    handle_parsing_errors=True,  # 让AgentExecutor捕获剩余的错误
    max_iterations=15,  # 防止无限循环
    early_stopping_method="generate",  # 当达到迭代限制时,生成最终答案
    return_intermediate_steps=False,  # 不返回中间步骤,减少干扰
)
​
# 语义缓存
try:
    if config.REDIS_URL:
        redis_cache = RedisSemanticCache(redis_url=config.REDIS_URL, embedding=embeddings)
        set_llm_cache(redis_cache)
        print("Redis 语义缓存已启用")
except Exception as e:
    print(f"Redis 缓存初始化失败: {e}")
​
# 会话记忆(SQLChatMessageHistory)
​
def _get_history(session_id: str):
    return sql_message_history.get_session_history(session_id)
​
CHAIN_WITH_HISTORY = RunnableWithMessageHistory(
    runnable=agent_executor,
    get_session_history=_get_history,
    input_messages_key="input",
    history_messages_key="chat_history",
)
​
print("Agent 初始化完成")

对应的自定义输入解析器代码如下:在agent目录下新建custom_react_parser.py

"""
自定义ReAct输出解析器 - 解决LangChain严格解析问题
"""import re
from typing import Union
from langchain.agents import AgentOutputParser
from langchain.agents.agent import AgentFinish
from langchain.schema import AgentAction
from langchain_core.exceptions import OutputParserException
​
​
class CustomReActOutputParser(AgentOutputParser):
    """
    自定义的ReAct输出解析器,功能增强:
    1. 智能处理不需要使用工具的输出(如问候语)
    2. 捕获'Missing Action:'错误并自动转换为AgentFinish
    3. 更宽松的格式检查
    """
​
    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        # 移除首尾空白
        text = text.strip()
​
        # 检查是否是使用工具的情况(包含Action)
        includes_action = "Action:" in text
        includes_answer = "Answer:" in text
​
        # 如果包含"我现在知道最终答案",说明要返回最终答案
        if "我现在知道最终答案" in text and includes_answer:
            # 提取Answer内容(在Answer:之后的所有内容)
            answer_match = re.search(r"Answer:(.*)$", text, re.DOTALL)
            if answer_match:
                answer = answer_match.group(1).strip()
            else:
                answer = text.split("Answer:")[-1].strip()
​
            return AgentFinish(
                return_values={"output": answer},
                log=text,
            )
​
        # 如果包含Action和Action Input,按标准ReAct格式解析
        if includes_action and "Action Input:" in text:
            try:
                # 提取Action
                action_match = re.search(r"Action:\s*(.+?)(?:\n|$)", text, re.IGNORECASE)
                # 提取Action Input
                action_input_match = re.search(
                    r"Action Input:\s*{(.+)}", text, re.DOTALL
                )
​
                if not action_match:
                    raise OutputParserException(
                        f"Could not parse LLM output: `{text}`",
                        observation="Invalid Format: Could not parse Action",
                        llm_output=text,
                        send_to_llm=True,
                    )
​
                if not action_input_match:
                    raise OutputParserException(
                        f"Could not parse LLM output: `{text}`",
                        observation="Invalid Format: Could not parse Action Input",
                        llm_output=text,
                        send_to_llm=True,
                    )
​
                action = action_match.group(1).strip()
                action_input_str = "{" + action_input_match.group(1).strip() + "}"
​
                return AgentAction(action, action_input_str, text)
​
            except Exception as e:
                # 如果解析失败,检查是否是因为不需要工具
                if "我现在知道最终答案" in text and not includes_action:
                    # 这是不需要工具的情况,直接返回最终答案
                    return AgentFinish(
                        return_values={"output": text.split("Answer:")[-1].strip() if "Answer:" in text else text},
                        log=text,
                    )
​
                raise OutputParserException(
                    f"Could not parse LLM output: `{text}`",
                    observation=f"Invalid Format: {str(e)}",
                    llm_output=text,
                    send_to_llm=True,
                )
​
        # 如果只有Thought和Answer(没有Action),认为是最终答案
        if "Thought:" in text and includes_answer:
            answer_match = re.search(r"Answer:(.*)$", text, re.DOTALL)
            if answer_match:
                answer = answer_match.group(1).strip()
                return AgentFinish(
                    return_values={"output": answer},
                    log=text,
                )
​
        # 如果没有Action,但有"我现在知道最终答案",也认为是最终答案
        if "我现在知道最终答案" in text:
            # 提取Answer部分
            if "Answer:" in text:
                answer = text.split("Answer:")[-1].strip()
            else:
                # 如果没有Answer,取Thought之后的所有内容
                answer = text.split("Thought:")[-1].strip() if "Thought:" in text else text
                # 移除"我现在知道最终答案"
                answer = answer.replace("我现在知道最终答案", "").strip()
​
            return AgentFinish(
                return_values={"output": answer},
                log=text,
            )
​
        # 如果既不是有效的Action格式,也没有明确答案,尝试兜底处理
        # 但如果是Missing Action错误,尝试挽救
        if "Thought:" in text and not includes_action:
            # LLM可能只输出了Thought,没有Action,也没有明确说"我现在知道最终答案"
            # 这可能是问候语,尝试提取内容并作为最终答案
            thought_content = text.split("Thought:")[-1].strip()
            if thought_content and len(thought_content) < 200:
                # 较短的文本,很可能是问候语
                return AgentFinish(
                    return_values={"output": thought_content},
                    log=text,
                )
​
        # 兜底处理:如果以上所有条件都不符合,直接将整个文本作为最终答案
        # 这样可以避免抛出异常导致循环,确保任何输出都能被处理
        print(f"[警告] LLM输出格式无法识别,作为普通文本处理: {text[:100]}...")
        return AgentFinish(
            return_values={"output": text},
            log=text,
        )
​
    @property
    def _type(self) -> str:
        return "custom_react"

再来重启项目

现在问题没有了:

image-20251209092502217

我们把这里展开,可以看到agent的执行流程

image-20251209092709770

image-20251209092553761

我们再来测试一下,比如,问一个怎么自定义线程池:

控制台输出如下:

image-20251209092851007

ai已经去知识库中检索了。

注意,这时候,我们可能发现,Redis语义缓存好像没有生效,这是因为Redis的初始化的时候文本嵌入模型还未加载,把Redis初始化提前就好。

可以按照这样的初始化顺序,把需要初始化llm和文本嵌入模型的代码提前,

并删除这一行 embeddings = DashScopeEmbeddings(model="text-embedding-v2") 因为如果不删除就有两个文本嵌入模型实例(我们已经提前初始化文本嵌入模型了)

# 注入 DashScope API Key
if config.DASHSCOPE_API_KEY:
    os.environ["DASHSCOPE_API_KEY"] = config.DASHSCOPE_API_KEY
​
# 通用组件(提前创建)
# 先创建基础组件供后续使用
embeddings = DashScopeEmbeddings(model="text-embedding-v2")
​
# LLM
llm = ChatTongyi(
    model=config.MODEL,
    temperature=config.TEMPERATURE,
)
​
# Redis语义缓存(立即初始化)
# 放在LLM创建之后、文档加载之前,确保所有LLM调用都能被缓存
try:
    if config.REDIS_URL:
        redis_cache = RedisSemanticCache(redis_url=config.REDIS_URL, embedding=embeddings)
        set_llm_cache(redis_cache)
        print(f"Redis 语义缓存已启用: {config.REDIS_URL}")
except Exception as e:
    print(f"Redis 缓存初始化失败: {e}")
    print("继续使用无缓存模式")

再来重启项目测试应该就生效了

总结

至此,我们已经完成了LangChain ReAct智能体的完整开发,包括:

  1. 环境搭建与项目初始化
  2. LangChain基础学习
  3. RAG智能体构建
  4. 对话记忆实现
  5. ReAct智能体开发
  6. 工具开发

您现在拥有一个功能完整的Java开发助手智能体!