构建自定义AI客户支持助手

55 阅读18分钟

自定义AI客户支持助手可以通过处理常规问题来释放团队时间,但构建一个能够可靠处理各种任务而不会让用户抓狂的智能助手可能很困难。

在我们这一章节中,我们将为一家航空公司构建自定义AI客户支持助手,帮助用户研究和安排旅行。我们将使用LangGraph的中断和检查点以及更复杂的状态来组织助手的工具并管理用户的航班预订、酒店预订、汽车租赁和游览。本章节需要我们熟悉前面的LangGraph的基本概念和用法。

完成后,我们将构建出一个可工作的智能助手,并了解LangGraph的关键概念和架构。我们将能够将这些设计模式应用到其他AI项目中。

自定义AI助手的架构将如下图所示: 在这里插入图片描述

准备工作 首先设置环境。我们将安装本教程的先决条件,下载测试数据库,并定义在每个部分中重复使用的工具。

我们将使用ChatGPT作为LLM,并定义多个自定义工具。虽然大多数工具将连接到本地sqlite数据库(不需要额外依赖),我们还将使用Tavily为代理提供通用网页搜索。

%%capture --no-stderr
%pip install -U langgraph langchain-community langchain-anthropic tavily-python pandas openai

import getpass
import os

def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")

_set_env("ANTHROPIC_API_KEY")
_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")

填充数据库

运行下一个python文件来获取我们为本教程准备的sqlite数据库,并更新它使其看起来是最新的。具体文件如下:

import os
import shutil
import sqlite3
import pandas as pd
import requests

db_url = "https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
local_file = "travel2.sqlite"
# 备份文件让我们可以在每个教程部分重新开始
backup_file = "travel2.backup.sqlite"
overwrite = False

if overwrite or not os.path.exists(local_file):
    response = requests.get(db_url)
    response.raise_for_status()  # 确保请求成功
    with open(local_file, "wb") as f:
        f.write(response.content)
    # 备份 - 我们将使用这个来"重置"每个部分的数据库
    shutil.copy(local_file, backup_file)

# 将航班转换为当前时间用于教程
def update_dates(file):
    shutil.copy(backup_file, file)
    conn = sqlite3.connect(file)
    cursor = conn.cursor()

    tables = pd.read_sql(
        "SELECT name FROM sqlite_master WHERE type='table';", conn
    ).name.tolist()
    tdf = {}
    for t in tables:
        tdf[t] = pd.read_sql(f"SELECT * from {t}", conn)

    example_time = pd.to_datetime(
        tdf["flights"]["actual_departure"].replace("\\N", pd.NaT)
    ).max()
    current_time = pd.to_datetime("now").tz_localize(example_time.tz)
    time_diff = current_time - example_time

    tdf["bookings"]["book_date"] = (
        pd.to_datetime(tdf["bookings"]["book_date"].replace("\\N", pd.NaT), utc=True)
        + time_diff
    )

    datetime_columns = [
        "scheduled_departure",
        "scheduled_arrival",
        "actual_departure",
        "actual_arrival",
    ]
    for column in datetime_columns:
        tdf["flights"][column] = (
            pd.to_datetime(tdf["flights"][column].replace("\\N", pd.NaT)) + time_diff
        )

    for table_name, df in tdf.items():
        df.to_sql(table_name, conn, if_exists="replace", index=False)
    del df
    del tdf
    conn.commit()
    conn.close()

    return file

db = update_dates(local_file)

工具定义

接下来,定义助手的工具来搜索航空公司的政策手册,以及搜索和管理航班、酒店、汽车租赁和游览的预订。我们将在整个教程中重复使用这些工具。确切的实现并不重要,因此请随意运行下面的代码并跳转到第一部分。

查找公司政策 助手检索政策信息来回答用户问题。请注意,这些政策的执行仍然必须在工具/API本身中完成,因为LLM总是可以忽略这些。

import re
import numpy as np
import openai
from langchain_core.tools import tool
response = requests.get(
    "https://storage.googleapis.com/benchmarks-artifacts/travel-db/swiss_faq.md"
)
response.raise_for_status()
faq_text = response.text
docs = [{"page_content": txt} for txt in re.split(r"(?=\n##)", faq_text)]
class VectorStoreRetriever:
    def __init__(self, docs: list, vectors: list, oai_client):
        self._arr = np.array(vectors)
        self._docs = docs
        self._client = oai_client
    @classmethod
    def from_docs(cls, docs, oai_client):
        embeddings = oai_client.embeddings.create(
            model="text-embedding-3-small", 
            input=[doc["page_content"] for doc in docs]
        )
        vectors = [emb.embedding for emb in embeddings.data]
        return cls(docs, vectors, oai_client)
    def query(self, query: str, k: int = 5) -> list[dict]:
        embed = self._client.embeddings.create(
            model="text-embedding-3-small", input=[query]
        )
        # "@" 在Python中只是矩阵乘法
        scores = np.array(embed.data[0].embedding) @ self._arr.T
        top_k_idx = np.argpartition(scores, -k)[-k:]
        top_k_idx_sorted = top_k_idx[np.argsort(-scores[top_k_idx])]
        return [
            {**self._docs[idx], "similarity": scores[idx]} 
            for idx in top_k_idx_sorted
        ]
retriever = VectorStoreRetriever.from_docs(docs, openai.Client())
@tool
def lookup_policy(query: str) -> str:
    """查询公司政策以检查是否允许某些选项。
    在进行任何航班更改或执行其他'写入'事件之前使用此工具。"""
    docs = retriever.query(query, k=2)
    return "\n\n".join([doc["page_content"] for doc in docs])

航班工具

定义 fetch_user_flight_information 工具让代理查看当前用户的航班信息。然后定义工具来搜索航班并管理存储在SQL数据库中的乘客预订。

我们然后可以访问给定运行的RunnableConfig来检查访问此应用程序的用户的passenger_id。LLM从不必须明确提供这些,它们是为图的给定调用提供的,这样每个用户就不能访问其他乘客的预订信息。需要注意的是教程期望 langchain-core>=0.2.16 以使用注入的RunnableConfig。在此之前,我们可以使用 ensure_config 从上下文中收集配置。

import sqlite3
from datetime import date, datetime
from typing import Optional
import pytz
from langchain_core.runnables import RunnableConfig

@tool
def fetch_user_flight_information(config: RunnableConfig) -> list[dict]:
    """获取用户的所有机票以及相应的航班信息和座位分配。

    返回:
        包含票务详情、相关航班详情和属于用户的每张票的座位分配的字典列表。
    """
    configuration = config.get("configurable", {})
    passenger_id = configuration.get("passenger_id", None)
    if not passenger_id:
        raise ValueError("未配置乘客ID。")

    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    query = """
    SELECT 
        t.ticket_no, t.book_ref,
        f.flight_id, f.flight_no, f.departure_airport, f.arrival_airport, 
        f.scheduled_departure, f.scheduled_arrival,
        bp.seat_no, tf.fare_conditions
    FROM 
        tickets t
        JOIN ticket_flights tf ON t.ticket_no = tf.ticket_no
        JOIN flights f ON tf.flight_id = f.flight_id
        JOIN boarding_passes bp ON bp.ticket_no = t.ticket_no AND bp.flight_id = f.flight_id
    WHERE 
        t.passenger_id = ?
    """
    cursor.execute(query, (passenger_id,))
    rows = cursor.fetchall()
    column_names = [column[0] for column in cursor.description]
    results = [dict(zip(column_names, row)) for row in rows]

    cursor.close()
    conn.close()
    return results

@tool
def search_flights(
    departure_airport: Optional[str] = None,
    arrival_airport: Optional[str] = None,
    start_time: Optional[date | datetime] = None,
    end_time: Optional[date | datetime] = None,
    limit: int = 20,
) -> list[dict]:
    """根据出发机场、到达机场和出发时间范围搜索航班。"""
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    query = "SELECT * FROM flights WHERE 1 = 1"
    params = []

    if departure_airport:
        query += " AND departure_airport = ?"
        params.append(departure_airport)

    if arrival_airport:
        query += " AND arrival_airport = ?"
        params.append(arrival_airport)

    if start_time:
        query += " AND scheduled_departure >= ?"
        params.append(start_time)

    if end_time:
        query += " AND scheduled_departure <= ?"
        params.append(end_time)
    query += " LIMIT ?"
    params.append(limit)
    cursor.execute(query, params)
    rows = cursor.fetchall()
    column_names = [column[0] for column in cursor.description]
    results = [dict(zip(column_names, row)) for row in rows]

    cursor.close()
    conn.close()
    return results

@tool
def update_ticket_to_new_flight(
    ticket_no: str, new_flight_id: int, *, config: RunnableConfig
) -> str:
    """将用户的机票更新为新的有效航班。"""
    configuration = config.get("configurable", {})
    passenger_id = configuration.get("passenger_id", None)
    if not passenger_id:
        raise ValueError("未配置乘客ID。")

    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    cursor.execute(
        "SELECT departure_airport, arrival_airport, scheduled_departure FROM flights WHERE flight_id = ?",
        (new_flight_id,),
    )
    new_flight = cursor.fetchone()
    if not new_flight:
        cursor.close()
        conn.close()
        return "提供的新航班ID无效。"
    column_names = [column[0] for column in cursor.description]
    new_flight_dict = dict(zip(column_names, new_flight))
    timezone = pytz.timezone("Etc/GMT-3")
    current_time = datetime.now(tz=timezone)
    departure_time = datetime.strptime(
        new_flight_dict["scheduled_departure"], "%Y-%m-%d %H:%M:%S.%f%z"
    )
    time_until = (departure_time - current_time).total_seconds()
    if time_until < (3 * 3600):
        return f"不允许改签到距离当前时间不足3小时的航班。所选航班时间为 {departure_time}。"

    cursor.execute(
        "SELECT flight_id FROM ticket_flights WHERE ticket_no = ?", (ticket_no,)
    )
    current_flight = cursor.fetchone()
    if not current_flight:
        cursor.close()
        conn.close()
        return "未找到给定票号的现有机票。"

    # 检查登录用户是否确实拥有此机票
    cursor.execute(
        "SELECT * FROM tickets WHERE ticket_no = ? AND passenger_id = ?",
        (ticket_no, passenger_id),
    )
    current_ticket = cursor.fetchone()
    if not current_ticket:
        cursor.close()
        conn.close()
        return f"当前登录乘客ID {passenger_id} 不是机票 {ticket_no} 的所有者"

    # 在实际应用中,您可能会在这里添加额外的检查来执行业务逻辑,
    # 比如"新的出发机场是否与当前机票匹配"等等。
    # 虽然最好尝试向LLM主动"类型提示"政策
    # 但它不可避免地会出错,所以您**也**需要确保您的
    # API强制执行有效行为
    cursor.execute(
        "UPDATE ticket_flights SET flight_id = ? WHERE ticket_no = ?",
        (new_flight_id, ticket_no),
    )
    conn.commit()

    cursor.close()
    conn.close()
    return "机票已成功更新为新航班。"

@tool
def cancel_ticket(ticket_no: str, *, config: RunnableConfig) -> str:
    """取消用户的机票并从数据库中删除。"""
    configuration = config.get("configurable", {})
    passenger_id = configuration.get("passenger_id", None)
    if not passenger_id:
        raise ValueError("未配置乘客ID。")
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    cursor.execute(
        "SELECT flight_id FROM ticket_flights WHERE ticket_no = ?", (ticket_no,)
    )
    existing_ticket = cursor.fetchone()
    if not existing_ticket:
        cursor.close()
        conn.close()
        return "未找到给定票号的现有机票。"

    # 检查登录用户是否确实拥有此机票
    cursor.execute(
        "SELECT ticket_no FROM tickets WHERE ticket_no = ? AND passenger_id = ?",
        (ticket_no, passenger_id),
    )
    current_ticket = cursor.fetchone()
    if not current_ticket:
        cursor.close()
        conn.close()
        return f"当前登录乘客ID {passenger_id} 不是机票 {ticket_no} 的所有者"

    cursor.execute("DELETE FROM ticket_flights WHERE ticket_no = ?", (ticket_no,))
    conn.commit()

    cursor.close()
    conn.close()
    return "机票已成功取消。"

汽车租赁工具

用户预订航班后,可能需要安排交通。定义一些"汽车租赁"工具,让用户搜索并在目的地预订汽车。

from datetime import date, datetime
from typing import Optional, Union

@tool
def search_car_rentals(
    location: Optional[str] = None,
    name: Optional[str] = None,
    price_tier: Optional[str] = None,
    start_date: Optional[Union[datetime, date]] = None,
    end_date: Optional[Union[datetime, date]] = None,
) -> list[dict]:
    """
    根据位置、名称、价格等级、开始日期和结束日期搜索汽车租赁。

    参数:
        location (Optional[str]): 汽车租赁的位置。默认为None。
        name (Optional[str]): 汽车租赁公司的名称。默认为None。
        price_tier (Optional[str]): 汽车租赁的价格等级。默认为None。
        start_date (Optional[Union[datetime, date]]): 汽车租赁的开始日期。默认为None。
        end_date (Optional[Union[datetime, date]]): 汽车租赁的结束日期。默认为None。

    返回:
        list[dict]: 符合搜索条件的汽车租赁字典列表。
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    query = "SELECT * FROM car_rentals WHERE 1=1"
    params = []

    if location:
        query += " AND location LIKE ?"
        params.append(f"%{location}%")
    if name:
        query += " AND name LIKE ?"
        params.append(f"%{name}%")
    # 对于我们的教程,我们将允许您匹配任何日期和价格等级。
    # (因为我们的玩具数据集没有太多数据)
    cursor.execute(query, params)
    results = cursor.fetchall()

    conn.close()

    return [
        dict(zip([column[0] for column in cursor.description], row)) for row in results
    ]

@tool
def book_car_rental(rental_id: int) -> str:
    """
    通过ID预订汽车租赁。

    参数:
        rental_id (int): 要预订的汽车租赁的ID。

    返回:
        str: 指示汽车租赁是否成功预订的消息。
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    cursor.execute("UPDATE car_rentals SET booked = 1 WHERE id = ?", (rental_id,))
    conn.commit()

    if cursor.rowcount > 0:
        conn.close()
        return f"汽车租赁 {rental_id} 已成功预订。"
    else:
        conn.close()
        return f"未找到ID为 {rental_id} 的汽车租赁。"

@tool
def update_car_rental(
    rental_id: int,
    start_date: Optional[Union[datetime, date]] = None,
    end_date: Optional[Union[datetime, date]] = None,
) -> str:
    """
    通过ID更新汽车租赁的开始和结束日期。

    参数:
        rental_id (int): 要更新的汽车租赁的ID。
        start_date (Optional[Union[datetime, date]]): 汽车租赁的新开始日期。默认为None。
        end_date (Optional[Union[datetime, date]]): 汽车租赁的新结束日期。默认为None。

    返回:
        str: 指示汽车租赁是否成功更新的消息。
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    if start_date:
        cursor.execute(
            "UPDATE car_rentals SET start_date = ? WHERE id = ?",
            (start_date, rental_id),
        )
    if end_date:
        cursor.execute(
            "UPDATE car_rentals SET end_date = ? WHERE id = ?", (end_date, rental_id)
        )

    conn.commit()

    if cursor.rowcount > 0:
        conn.close()
        return f"汽车租赁 {rental_id} 已成功更新。"
    else:
        conn.close()
        return f"未找到ID为 {rental_id} 的汽车租赁。"

@tool
def cancel_car_rental(rental_id: int) -> str:
    """
    通过ID取消汽车租赁。

    参数:
        rental_id (int): 要取消的汽车租赁的ID。

    返回:
        str: 指示汽车租赁是否成功取消的消息。
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    cursor.execute("UPDATE car_rentals SET booked = 0 WHERE id = ?", (rental_id,))
    conn.commit()

    if cursor.rowcount > 0:
        conn.close()
        return f"汽车租赁 {rental_id} 已成功取消。"
    else:
        conn.close()
        return f"未找到ID为 {rental_id} 的汽车租赁。"
酒店工具
用户需要住宿!定义一些工具来搜索和管理酒店预订。

@tool
def search_hotels(
    location: Optional[str] = None,
    name: Optional[str] = None,
    price_tier: Optional[str] = None,
    checkin_date: Optional[Union[datetime, date]] = None,
    checkout_date: Optional[Union[datetime, date]] = None,
) -> list[dict]:
    """
    根据位置、名称、价格等级、入住日期和退房日期搜索酒店。

    参数:
        location (Optional[str]): 酒店的位置。默认为None。
        name (Optional[str]): 酒店的名称。默认为None。
        price_tier (Optional[str]): 酒店的价格等级。默认为None。示例:中档、高档中档、高端、豪华
        checkin_date (Optional[Union[datetime, date]]): 酒店的入住日期。默认为None。
        checkout_date (Optional[Union[datetime, date]]): 酒店的退房日期。默认为None。

    返回:
        list[dict]: 符合搜索条件的酒店字典列表。
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    query = "SELECT * FROM hotels WHERE 1=1"
    params = []

    if location:
        query += " AND location LIKE ?"
        params.append(f"%{location}%")
    if name:
        query += " AND name LIKE ?"
        params.append(f"%{name}%")
    # 为了本教程的目的,我们将允许您匹配任何日期和价格等级。
    cursor.execute(query, params)
    results = cursor.fetchall()

    conn.close()

    return [
        dict(zip([column[0] for column in cursor.description], row)) for row in results
    ]

@tool
def book_hotel(hotel_id: int) -> str:
    """
    通过ID预订酒店。

    参数:
        hotel_id (int): 要预订的酒店的ID。

    返回:
        str: 指示酒店是否成功预订的消息。
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    cursor.execute("UPDATE hotels SET booked = 1 WHERE id = ?", (hotel_id,))
    conn.commit()

    if cursor.rowcount > 0:
        conn.close()
        return f"酒店 {hotel_id} 已成功预订。"
    else:
        conn.close()
        return f"未找到ID为 {hotel_id} 的酒店。"

@tool
def update_hotel(
    hotel_id: int,
    checkin_date: Optional[Union[datetime, date]] = None,
    checkout_date: Optional[Union[datetime, date]] = None,
) -> str:
    """
    通过ID更新酒店的入住和退房日期。

    参数:
        hotel_id (int): 要更新的酒店的ID。
        checkin_date (Optional[Union[datetime, date]]): 酒店的新入住日期。默认为None。
        checkout_date (Optional[Union[datetime, date]]): 酒店的新退房日期。默认为None。

    返回:
        str: 指示酒店是否成功更新的消息。
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    if checkin_date:
        cursor.execute(
            "UPDATE hotels SET checkin_date = ? WHERE id = ?", (checkin_date, hotel_id)
        )
    if checkout_date:
        cursor.execute(
            "UPDATE hotels SET checkout_date = ? WHERE id = ?",
            (checkout_date, hotel_id),
        )

    conn.commit()

    if cursor.rowcount > 0:
        conn.close()
        return f"酒店 {hotel_id} 已成功更新。"
    else:
        conn.close()
        return f"未找到ID为 {hotel_id} 的酒店。"

@tool
def cancel_hotel(hotel_id: int) -> str:
    """
    通过ID取消酒店。

    参数:
        hotel_id (int): 要取消的酒店的ID。

    返回:
        str: 指示酒店是否成功取消的消息。
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    cursor.execute("UPDATE hotels SET booked = 0 WHERE id = ?", (hotel_id,))
    conn.commit()

    if cursor.rowcount > 0:
        conn.close()
        return f"酒店 {hotel_id} 已成功取消。"
    else:
        conn.close()
        return f"未找到ID为 {hotel_id} 的酒店。"
游览工具
最后,定义一些工具让用户搜索到达后的活动(并进行预订)。

@tool
def search_trip_recommendations(
    location: Optional[str] = None,
    name: Optional[str] = None,
    keywords: Optional[str] = None,
) -> list[dict]:
    """
    根据位置、名称和关键词搜索旅行推荐。

    参数:
        location (Optional[str]): 旅行推荐的位置。默认为None。
        name (Optional[str]): 旅行推荐的名称。默认为None。
        keywords (Optional[str]): 与旅行推荐相关的关键词。默认为None。

    返回:
        list[dict]: 符合搜索条件的旅行推荐字典列表。
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    query = "SELECT * FROM trip_recommendations WHERE 1=1"
    params = []

    if location:
        query += " AND location LIKE ?"
        params.append(f"%{location}%")
    if name:
        query += " AND name LIKE ?"
        params.append(f"%{name}%")
    if keywords:
        keyword_list = keywords.split(",")
        keyword_conditions = " OR ".join(["keywords LIKE ?" for _ in keyword_list])
        query += f" AND ({keyword_conditions})"
        params.extend([f"%{keyword.strip()}%" for keyword in keyword_list])

    cursor.execute(query, params)
    results = cursor.fetchall()

    conn.close()

    return [
        dict(zip([column[0] for column in cursor.description], row)) for row in results
    ]

@tool
def book_excursion(recommendation_id: int) -> str:
    """
    通过推荐ID预订游览。

    参数:
        recommendation_id (int): 要预订的旅行推荐的ID。

    返回:
        str: 指示旅行推荐是否成功预订的消息。
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    cursor.execute(
        "UPDATE trip_recommendations SET booked = 1 WHERE id = ?", (recommendation_id,)
    )
    conn.commit()

    if cursor.rowcount > 0:
        conn.close()
        return f"旅行推荐 {recommendation_id} 已成功预订。"
    else:
        conn.close()
        return f"未找到ID为 {recommendation_id} 的旅行推荐。"

@tool
def update_excursion(recommendation_id: int, details: str) -> str:
    """
    通过ID更新旅行推荐的详情。

    参数:
        recommendation_id (int): 要更新的旅行推荐的ID。
        details (str): 旅行推荐的新详情。

    返回:
        str: 指示旅行推荐是否成功更新的消息。
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    cursor.execute(
        "UPDATE trip_recommendations SET details = ? WHERE id = ?",
        (details, recommendation_id),
    )
    conn.commit()

    if cursor.rowcount > 0:
        conn.close()
        return f"旅行推荐 {recommendation_id} 已成功更新。"
    else:
        conn.close()
        return f"未找到ID为 {recommendation_id} 的旅行推荐。"

实用工具

定义辅助函数来在调试时美化打印图中的消息,并为工具节点提供错误处理。

from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda
from langgraph.prebuilt import ToolNode
def handle_tool_error(state) -> dict:
    error = state.get("error")
    tool_calls = state["messages"][-1].tool_calls
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in tool_calls
        ]
    }
def create_tool_node_with_fallback(tools: list) -> dict:
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )
def _print_event(event: dict, _printed: set, max_length=1500):
    current_state = event.get("dialog_state")
    if current_state:
        print("Currently in: ", current_state[-1])
    message = event.get("messages")
    if message:
        if isinstance(message, list):
            message = message[-1]
        if message.id not in _printed:
            msg_repr = message.pretty_repr(html=True)
            if len(msg_repr) > max_length:
                msg_repr = msg_repr[:max_length] + " ... (truncated)"
            print(msg_repr)
            _printed.add(message.id)

第一部分:零样本代理

在本节中,我们将定义一个简单的零样本代理作为助手,为代理提供所有工具,并提示它明智地使用这些工具来协助用户。

简单的2节点图如下所示: 在这里插入图片描述

状态

将StateGraph的状态定义为包含仅追加消息列表的类型化字典。这些消息形成聊天历史记录,这是我们简单助手所需的全部状态。

from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage, add_messages

class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
代理
接下来,定义助手函数。该函数接受图状态,将其格式化为提示,然后调用LLM来预测最佳响应。

from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableConfig

class Assistant:
    def __init__(self, runnable: Runnable):
        self.runnable = runnable

    def __call__(self, state: State, config: RunnableConfig):
        while True:
            configuration = config.get("configurable", {})
            passenger_id = configuration.get("passenger_id", None)
            state = {**state, "user_info": passenger_id}
            result = self.runnable.invoke(state)
            # 如果LLM返回空响应,我们将重新提示它给出实际响应
            if not result.tool_calls and (
                not result.content
                or isinstance(result.content, list)
                and not result.content[0].get("text")
            ):
                messages = state["messages"] + [("user", "请给出真实的输出。")]
                state = {**state, "messages": messages}
            else:
                break
        return {"messages": result}

# Haiku更快更便宜,但不太准确
# llm = ChatAnthropic(model="claude-3-haiku-20240307")
llm = ChatOpenAI(model="gpt-4o", temperature=1)

primary_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "您是瑞士航空公司的有用客户支持助手。"
            " 使用提供的工具搜索航班、公司政策和其他信息来协助用户查询。"
            " 搜索时要持之以恒。如果第一次搜索没有返回结果,请扩大查询范围。"
            " 如果搜索结果为空,请在放弃之前扩大搜索范围。"
            "\n\n当前用户:\n<User>\n{user_info}\n</User>"
            "\n当前时间: {time}。",
        ),
        ("placeholder", "{messages}"),
    ]
).partial(time=datetime.now)

part_1_tools = [
    TavilySearchResults(max_results=1),
    fetch_user_flight_information,
    search_flights,
    lookup_policy,
    update_ticket_to_new_flight,
    cancel_ticket,
    search_car_rentals,
    book_car_rental,
    update_car_rental,
    cancel_car_rental,
    search_hotels,
    book_hotel,
    update_hotel,
    cancel_hotel,
    search_trip_recommendations,
    book_excursion,
    update_excursion,
    cancel_excursion,
]

part_1_assistant_runnable = primary_assistant_prompt | llm.bind_tools(part_1_tools)

定义图

现在创建图。图是本节的最终助手。

from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import tools_condition

builder = StateGraph(State)

# 定义节点:执行工作
builder.add_node("assistant", Assistant(part_1_assistant_runnable))
builder.add_node("tools", create_tool_node_with_fallback(part_1_tools))

# 定义边:确定控制流如何移动
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    tools_condition,
)
builder.add_edge("tools", "assistant")

# 检查点器让图持久化其状态
# 这是整个图的完整内存
memory = InMemorySaver()
part_1_graph = builder.compile(checkpointer=memory)
示例对话
现在是时候试试我们强大的聊天机器人了!让我们在以下对话轮次列表上运行它。如果遇到"RecursionLimit",这意味着代理无法在分配的步骤数内得到答案。没关系!在本教程的后续部分中,我们还有更多方法。

import shutil
import uuid

# 创建用户可能与助手进行的示例对话
tutorial_questions = [
    "您好,我的航班是什么时间?",
    "我可以将航班更新为更早的时间吗?我想今天晚些时候离开。",
    "那么将我的航班更新为下周某个时间",
    "下一个可用选项很好",
    "住宿和交通怎么样?",
    "是的,我想要一个经济实惠的酒店,住一周(7天)。我还想租一辆车。",
    "好的,您能为您推荐的酒店预订吗?听起来不错。",
    "是的,请继续预订任何费用适中且有空房的酒店。",
    "现在关于汽车,我有什么选择?",
    "太棒了,让我们选择最便宜的选项。请预订7天",
    "很好,现在您对游览有什么建议?",
    "我在那里的时候有这些活动吗?",
    "有趣 - 我喜欢博物馆,有什么选择?",
    "好的,很好,选择一个并为我在那里的第二天预订。",
]

# 使用备份文件更新,这样我们可以在每个部分从原始位置重新开始
db = update_dates(db)
thread_id = str(uuid.uuid4())

config = {
    "configurable": {
        # passenger_id用于我们的航班工具来获取用户的航班信息
        "passenger_id": "3442 587242",
        # 检查点通过thread_id访问
        "thread_id": thread_id,
    }
}

_printed = set()
for question in tutorial_questions:
    events = part_1_graph.stream(
        {"messages": ("user", question)}, config, stream_mode="values"
    )
    for event in events:
        _print_event(event, _printed)

第一部分回顾

我们的简单助手还不错!它能够合理地回应所有问题,快速地在上下文中响应,并成功执行所有任务。

如果这是一个简单的问答机器人,我们可能会对上述结果感到满意。由于我们的客户支持机器人代表用户采取行动,上述的一些行为有点令人担忧:

未经确认的预订,助手在我们专注于住宿时预订了汽车,然后不得不稍后取消和重新预订:糟糕!用户应该在预订前有最终决定权,以避免不需要的费用。 搜索困难,助手在搜索推荐时遇到困难。我们可以通过为工具添加更详细的说明和示例来改进这一点,但为每个工具这样做可能导致大型提示和不知所措的代理。 效率问题,助手必须进行明确搜索才能获得用户的相关信息。我们可以通过立即获取用户的相关旅行详情来节省大量时间,这样助手就可以直接回应。

在下一节中,我们将解决前两个问题。

总结

本教程的第一部分展示了如何构建一个基本的客户支持机器人,后续部分将通过更高级的LangGraph功能(如中断和检查点)来解决这些问题,创建更可靠和用户友好的客户支持体验。这个教程为构建复杂的AI客户支持系统提供了坚实的基础,展示了如何将多个工具和服务集成到一个连贯的对话界面中。