基于LangChain的RAG应用开发(05)-基于SQL查询数据的RAG应用

3 阅读5分钟

摘要

之前的RAG应用都是从向量库等存储组件中获取相关内容,并生成回答。现在需要实现,从结构化数据库中获取数据,生成相应的回答,两者具有本质的区别。

SQL-RAG实现的大致流程如下:

  • 将用户问题query转换为特定语言的SQL,该步骤一般由大模型完成;
  • 执行生成的SQL查询语句
  • 模型根据查询结果和用户输入query生成回答

数据库初始化

将Chinook_Sqlite.sql保存到项目路径下,在命令行终端执行sqlite3 Chinook.db.read Chinook.sql命令创建数据库。然后执行如下代码检查是否创建成功:

from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect) # 数据库方言
print(db.get_usable_table_names()) # 数据库中的表名
db.run("SELECT * FROM Artist LIMIT 10;") 

大模型初始化

代码如下:

from langchain_openai import ChatOpenAI

from dotenv import load_dotenv
import os
load_dotenv()

api_key = os.getenv("api_key")
base_url = os.getenv("base_url")

model = ChatOpenAI(model = "deepseek-chat",api_key=api_key, base_url=base_url)

基于链实现SQL-RAG应用

基本步骤:NL转SQL——》执行SQL——》根据结果生成答案

根据用户query生成SQL

使用create_sql_query_chain生成SQL的生成链

代码如下:

from langchain.chains.sql_database.query import create_sql_query_chain

query_sql_chain = create_sql_query_chain(model, db) # 使用内置链构建query-SQL转换链
response = query_sql_chain.invoke({"question": "How many employees are there"})
response

代码解释:

create_sql_query_chain是LangChain关于NL2SQL的内置链,功能为:将用户Query转换为SQL。具体代码如下:

def create_sql_query_chain(
    llm: BaseLanguageModel, # 所使用的大模型
    db: SQLDatabase, # 数据库
    prompt: Optional[BasePromptTemplate] = None, # 提示词模板
    k: int = 5, # Select语句返回的结果数量
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
    """
    Prompt:
    	如果没有传入参数没有提供提示词模板,则会根据数据库的方言选择一个默认的提示词模板。如果需要自定义提示词模板,则需要提供如下参数:
    	{input}:这里面填写用户输入的query,并在query后添加"\nSQLQuery"后缀;
    	{table_info}:这里面填写数据库的表定义和列定义,一般用db.get_table_info获取;
    	{top_k}:这里填写每个SELECT语句返回的结果数;
    	{dialect}:这里填写数据库方言(可选)
    """
    if prompt is not None: # 如果是自定义prompt,使用自定义prompt
        prompt_to_use = prompt
    elif db.dialect in SQL_PROMPTS: # 如果不是自定义prompt,先判断数据库方言,根据方言获取对应的默认提示词模板
        prompt_to_use = SQL_PROMPTS[db.dialect]
    else: # 如果不是SQL_PROMPT定义中存在方言,则使用默认prompt
        prompt_to_use = PROMPT
    if {"input", "top_k", "table_info"}.difference(
        prompt_to_use.input_variables + list(prompt_to_use.partial_variables)
    ): # 判断{input}、{top_k}、{table_info}在prompt_to_use中是否存在,主要针对自定义prompt
        raise ValueError(
            f"Prompt must have input variables: 'input', 'top_k', "
            f"'table_info'. Received prompt with input variables: "
            f"{prompt_to_use.input_variables}. Full prompt:\n\n{prompt_to_use}"
        )
    if "dialect" in prompt_to_use.input_variables:
        prompt_to_use = prompt_to_use.partial(dialect=db.dialect) # 动态插入数据库方言参数

    inputs = {
        "input": lambda x: x["question"] + "\nSQLQuery: ", # 这就是需要在query后增加"\nSQLQuery"的原因
        "table_info": lambda x: db.get_table_info(
            table_names=x.get("table_names_to_use")
        ),
    }
    return (
        RunnablePassthrough.assign(**inputs)  # inputs输入
        | (
            lambda x: { # 将“question”和“table_names_to_use”之外的键值对传入
                k: v
                for k, v in x.items()
                if k not in ("question", "table_names_to_use")
            }
        )
        | prompt_to_use.partial(top_k=str(k))
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        | _strip
    )

执行SQL

使用QuerySQLDataBaseTool创建SQL执行链

代码如下:

from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_core.runnables import RunnableLambda

execute_query = QuerySQLDatabaseTool(db = db) # SQL执行链
write_query = create_sql_query_chain(model,db) # SQL生成链

def parser(sql_str): # 这个函数是为了将生成链生成的SQL去掉前缀,应该和使用的模型有关
    return sql_str.split('SQLQuery: ')[1]

# chain_1 = write_query | RunnableLambda(lambda x : x.split('SQLQuery: ')[1])| execute_query

chain = write_query | parser | execute_query # 这个链执行输出SQL的查询结果


chain.invoke({"question": "How many employees are there"})

代码解释:

QuerySQLDatabaseTool类,继承自BaseSQLDatabaseTool,是一个执行SQL查询工具类。需要传入参数:db数据库。执行SQL语句,核心代码是:db.run_no_throw(query)。

回答问题

只需将原始问题和 SQL 查询结果结合起来生成最终答案。

代码如下:

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

from operator import itemgetter

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

chain = (
    RunnablePassthrough.assign(query=write_query).assign(result=itemgetter("query") | RunnableLambda(parser) |execute_query)
    | answer_prompt
    | model
    | StrOutputParser()
)
"""
chain_test_1 = (
    RunnablePassthrough.assign(query = write_query | RunnableLambda(parser))
    | RunnablePassthrough.assign(result=itemgetter("query") | execute_query)
    | answer_prompt
    | model
    | StrOutputParser()
)

chain_test_2 = (
    {"query": write_query | RunnableLambda(parser),"question":itemgetter("question")}
    | RunnablePassthrough.assign(result=itemgetter("query") | execute_query)
    | answer_prompt
    | model
    | StrOutputParser()
)
"""

chain.invoke({"question": "How many employees are there"})

基于Agent实现SQL-RAG应用

基于Agent实现SQL-RAG时,需要使用一系列工具tools以供Agent使用,LangChain提供了一个相应的工具类SQLDatabaseToolkit,无需我们自己定义。之后创建Agent后,即可Agent自己调用。具体代码如下:

from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db = db, llm = model)
tools = toolkit.get_tools()

for i in tools:
    print(f"工具名称:{i.name},工具描述:{i.description}")

工具集tools中包含:

sql_db_query:执行SQL语句的工具,如果执行报错会返回错误信息。

sql_db_schema:根据输入的表名称,输出这些表的详细信息以及这些表的列信息。

sql_db_list_tables:输出数据库的表名称。

sql_db_query_checker:检查SQL语句是否正确。

from langchain_core.messages import SystemMessage

SQL_PREFIX = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

To start you should ALWAYS look at the tables in the database to see what you can query.
Do NOT skip this step.
Then you should query the schema of the most relevant tables."""

system_message = SystemMessage(content=SQL_PREFIX)

from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent

agent_executor = create_react_agent(model, tools, prompt=system_message)

for s in agent_executor.stream(
    {"messages": [HumanMessage(content="Which country's customers spent the most?")]}
):
    print(s)
    print("----")

创建Prompt,并使用create_react_agent创建一个Agent。

原文地址:https://www.cnblogs.com/AfroNicky/p/18916667