LlamaIndex(二):WorkFlow

296 阅读10分钟

WorkFlow的基本使用

WorkFlow,也就是工作流,作为程序员对这个东西并不陌生,因为写代码的过程中有很大一部分就是处理各种流程,对各种流程进行抽象。

假设现在有一个需求,整体流程是这样的

图片.png

任务开始执行事件A,然后并发执行事件B、C,等事件B、C全部执行完毕后,执行事件D,然后流程结束,并不复杂对不对?

但是在你绞尽脑汁刚刚写好代码,处理好并发问题,测试完毕之后,产品修改了需求:事件B、C不并发执行了,改成顺序执行……

在LlamaIndex中这种修改非常方便,看下面的代码

这是初版需求的工作流

from llama_index.core.workflow import Workflow, Event, StartEvent, StopEvent, Context, step, draw_all_possible_flows


class EventA(Event):
    """Emitted by ``step_a``; accepted by both ``step_b`` and ``step_c``."""


class EventB(Event):
    """Emitted by ``step_b``."""


class EventC(Event):
    """Emitted by ``step_c``."""


class EventD(Event):
    """Declared for the example but not used by any step."""


class WorkFlowTest(Workflow):
    """Fan-out / fan-in workflow: A, then B and C, then D once both are done."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @step
    def step_a(self, ctx: Context, ev: StartEvent) -> EventA:
        """Start the flow; EventA is accepted by both step_b and step_c."""
        return EventA()

    @step
    def step_b(self, ctx: Context, ev: EventA) -> EventB:
        """First branch triggered by EventA."""
        return EventB()

    @step
    def step_c(self, ctx: Context, ev: EventA) -> EventC:
        """Second branch triggered by EventA."""
        return EventC()

    @step
    def step_d(self, ctx: Context, ev: EventB | EventC) -> StopEvent | None:
        """Finish the workflow only after BOTH EventB and EventC have arrived.

        The step is invoked once per incoming event. ``collect_events``
        buffers them and returns ``None`` until one event of every listed
        type has been received; returning ``None`` emits nothing, so the
        workflow keeps waiting for the other branch instead of stopping on
        whichever branch finishes first.
        """
        if ctx.collect_events(ev, [EventB, EventC]) is None:
            return None
        return StopEvent()


这是修改之后的,主要改动的是 step_c、step_d 接收的事件类型(也就是参数类型)

from llama_index.core.workflow import Workflow, Event, StartEvent, StopEvent, Context, step, draw_all_possible_flows


class EventA(Event):
    """Emitted by ``step_a``; accepted by ``step_b``."""


class EventB(Event):
    """Emitted by ``step_b``; accepted by ``step_c``."""


class EventC(Event):
    """Emitted by ``step_c``; accepted by ``step_d``."""


class EventD(Event):
    """Declared for the example but not used by any step."""


class WorkFlowTest(Workflow):
    """Strictly sequential workflow: A -> B -> C -> D.

    Each step accepts exactly the event type produced by the previous step.
    No constructor is declared: the original one only forwarded its
    arguments to ``Workflow.__init__``, which is what inheritance does.
    """

    @step
    def step_a(self, ctx: Context, ev: StartEvent) -> EventA:
        """Turn the start event into EventA."""
        return EventA()

    @step
    def step_b(self, ctx: Context, ev: EventA) -> EventB:
        """Turn EventA into EventB."""
        return EventB()

    @step
    def step_c(self, ctx: Context, ev: EventB) -> EventC:
        """Turn EventB into EventC."""
        return EventC()

    @step
    def step_d(self, ctx: Context, ev: EventC) -> StopEvent:
        """Turn EventC into the terminating StopEvent."""
        return StopEvent()

图片.png

是不是非常方便!

在LlamaIndex中,工作流是由step组成的,从StartEvent开始,每个step会处理特定的Event,处理完毕后触发下一个事件交给下一个step处理,直到产生StopEvent

当然在LlamaIndex的工作流中不止能处理这种简单的顺序执行的场景,还支持循环,并发等多种情况

循环工作流

循环其实也非常简单,只需要定义一个Loop事件就可以了,根据官方文档,事件可以拥有任何自定义的名称 图片.png

import asyncio

from llama_index.core.workflow import Workflow, Event, StartEvent, StopEvent, Context, step, draw_all_possible_flows


class EventA(Event):
    """Emitted by ``step_a``; first trigger of ``step_b``."""

    msg: str


class EventB(Event):
    """Emitted by ``step_b`` once the loop is finished."""

    msg: str


class EventC(Event):
    """Emitted by ``step_c``."""

    msg: str


class LoopEvent(Event):
    """Emitted by ``step_b`` to trigger itself again."""

    msg: str


class WorkFlowTest(Workflow):
    """Workflow in which ``step_b`` re-triggers itself through ``LoopEvent``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.count = 5      # how many times step_b runs per workflow run
        self.loop_num = 1   # 1-based index of the upcoming step_b execution

    @step
    async def step_a(self, ctx: Context, ev: StartEvent) -> EventA:
        """Print the start message and hand over to step_b."""
        # The counter lives on the instance, so it has to be reset here:
        # otherwise a second run() on the same object would start with
        # loop_num already past count and step_b would not loop at all.
        # NOTE(review): instance state is still shared between overlapping
        # runs of one instance; keep per-run state in Context if you need that.
        self.loop_num = 1
        print(ev.msg)
        print('我是步骤a')
        return EventA(msg='事件A')

    @step
    async def step_b(self, ctx: Context, ev: EventA | LoopEvent) -> EventB | LoopEvent:
        """Run ``count`` times, re-emitting LoopEvent until the last pass."""
        print(f"我是步骤b,由{ev.msg}触发,循环次数{self.count} ,当前第{self.loop_num}次")
        self.loop_num += 1
        if self.loop_num <= self.count:
            return LoopEvent(msg="循环事件")
        return EventB(msg="事件b")

    @step
    async def step_c(self, ctx: Context, ev: EventB) -> EventC:
        """Runs once after the loop has finished."""
        print(f"我是步骤c,由{ev.msg}触发")
        return EventC(msg="事件c")

    @step
    async def step_d(self, ctx: Context, ev: EventC) -> StopEvent:
        """Terminate the workflow with a fixed result string."""
        print(f"我是步骤d,由{ev.msg}触发")
        return StopEvent(result="流程结束")

workflow = WorkFlowTest()


async def task(lock):
    """Run the module-level workflow once while holding ``lock``; print the result."""
    async with lock:
        outcome = await workflow.run(msg="流程开始")
        print(str(outcome))


async def main():
    """Create the lock and await the single workflow task."""
    shared_lock = asyncio.Lock()
    await asyncio.gather(task(shared_lock))


if __name__ == "__main__":
    asyncio.run(main())

图片.png

Context

在讲并发工作流之前,先说明一下步骤和步骤之间是怎么传递数据的。在上面的代码中,可以注意到每个步骤都有一个Context参数,这个参数就是用来在步骤之间共享数据的,有点类似于Go语言里的Context:数据放在上下文里按需读取,而不必塞进Event里在所有步骤间到处传递

class SetupEvent(Event):
    """Asks the ``setup`` step to load the data before the query is handled."""

    query: str


class StepTwoEvent(Event):
    """Carries the query onward once the data is already in the context."""

    query: str


class StatefulFlow(Workflow):
    """Workflow that shares data between steps through ``Context``.

    ``start`` checks the context for ``"some_database"``. If it is missing,
    ``setup`` stores it and re-emits a StartEvent so ``start`` runs again;
    otherwise the query is forwarded to ``step_two``, which ends the run.
    """

    @step
    async def start(
        self, ctx: Context, ev: StartEvent
    ) -> SetupEvent | StepTwoEvent:
        """Route to ``setup`` when the data is not loaded yet, else continue."""
        db = await ctx.get("some_database", default=None)
        if db is None:
            print("Need to load data")
            return SetupEvent(query=ev.query)

        # do something with the query
        return StepTwoEvent(query=ev.query)

    @step
    async def setup(self, ctx: Context, ev: SetupEvent) -> StartEvent:
        """Store the data in the context and restart from ``start``."""
        # load data
        await ctx.set("some_database", [1, 2, 3])
        return StartEvent(query=ev.query)

    @step
    async def step_two(self, ctx: Context, ev: StepTwoEvent) -> StopEvent:
        """Read the shared data back and finish the workflow.

        Without this step nothing accepts StepTwoEvent and nothing produces
        a StopEvent, so the workflow could never terminate.
        """
        data = await ctx.get("some_database")
        print("Data is ", data)
        return StopEvent(result=data)

并发工作流

并发工作流,也需要这个Context,具体可以看代码

图片.png

import asyncio

from llama_index.core.workflow import Workflow, Event, StartEvent, StopEvent, Context, step, draw_all_possible_flows


class EventA(Event):
    """Emitted by ``step_a``; triggers the fan-out in ``step_b``."""

    msg: str


class EventB(Event):
    """Sent by ``step_b``; accepted by ``step_c``."""

    msg: str


class EventC(Event):
    """Sent by ``step_b``; accepted by ``step_d``."""

    msg: str


class LoopEvent(Event):
    """Declared but not used by any step in this example."""

    msg: str


class WorkFlowTest(Workflow):
    """Workflow whose ``step_b`` fans out into two branches via ``send_event``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Not read by any step in this example; kept so the attributes
        # remain available to existing callers.
        self.count = 5
        self.loop_num = 1

    @step
    async def step_a(self, ctx: Context, ev: StartEvent) -> EventA:
        """Print the start message and hand over to step_b."""
        print(ev.msg)
        print('我是步骤a')
        return EventA(msg='事件A')

    @step
    async def step_b(self, ctx: Context, ev: EventA) -> EventB | EventC | None:
        """Emit EventB and EventC so step_c and step_d are both triggered.

        The events are dispatched with ``ctx.send_event`` and the step itself
        returns nothing, hence ``None`` in the return annotation.
        """
        ctx.send_event(EventB(msg='事件b'))
        # The original sent EventC labelled '事件b' (copy-paste slip), which
        # made step_d report that it had been triggered by event b.
        ctx.send_event(EventC(msg='事件c'))

    @step
    async def step_c(self, ctx: Context, ev: EventB) -> StopEvent:
        """Branch triggered by EventB."""
        # NOTE(review): step_c and step_d each return a StopEvent; the run
        # presumably ends on whichever arrives first - confirm if both
        # branches must always complete.
        print(f"我是步骤c,由{ev.msg}触发,我和d同时执行")
        return StopEvent(result="流程结束")

    @step
    async def step_d(self, ctx: Context, ev: EventC) -> StopEvent:
        """Branch triggered by EventC."""
        print(f"我是步骤d,由{ev.msg}触发,我和c同时执行")
        return StopEvent(result="流程结束")

workflow = WorkFlowTest()


async def task(lock):
    """Run the module-level workflow once while holding ``lock``; print the result."""
    async with lock:
        outcome = await workflow.run(msg="流程开始")
        print(str(outcome))


async def main():
    """Create the lock and await the single workflow task."""
    shared_lock = asyncio.Lock()
    # The gathered results are not needed: task() returns nothing.
    await asyncio.gather(task(shared_lock))


if __name__ == "__main__":
    asyncio.run(main())

图片.png

关于工作流还有更多的内容,具体可以参考官方文档

案例:使用自然语言进行数据库查询

需求说明

  1. 用户输入自然语言查询
  2. 系统先去检索跟查询相关的表
  3. 根据表的 Schema 让大模型生成 SQL
  4. 用生成的 SQL 查询数据库
  5. 根据查询结果,调用大模型生成自然语言回复

数据准备

# 下载 WikiTableQuestions
# WikiTableQuestions 是一个为表格问答设计的数据集。其中包含 2,108 个从维基百科提取的 HTML 表格

wget "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip" -O wiki_data.zip
unzip wiki_data.zip

数据入库

遍历目录加载表格

import pandas as pd
from pathlib import Path

# Read every CSV of the 200-csv folder into a DataFrame, in sorted path order.
data_dir = Path("./WikiTableQuestions/csv/200-csv")
csv_files = sorted(data_dir.glob("*.csv"))
dfs = []
for csv_file in csv_files:
    print(f"processing file: {csv_file}")
    try:
        df = pd.read_csv(csv_file)
    except Exception as e:
        # Best effort: report the unreadable file and continue with the rest.
        print(f"Error parsing {csv_file}: {e}")
    else:
        dfs.append(df)

图片.png

为每个表生成一段文字描述,保存到 WikiTableQuestions_TableInfo 目录

import os
import json

from llama_index.core.prompts import ChatPromptTemplate
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.llms.openai import OpenAI
from llama_index.core.llms import ChatMessage

# Directory in which one JSON summary file per table is cached.
tableinfo_dir = "WikiTableQuestions_TableInfo"

# exist_ok=True replaces the separate exists()/mkdir() pair, which could fail
# if the directory appeared between the check and the creation.
os.makedirs(tableinfo_dir, exist_ok=True)
    

class TableInfo(BaseModel):
    """Information regarding a structured table."""

    # Name to use for the table; used later as the SQL table name.
    table_name: str = Field(..., description="table name (must be underscores and NO spaces)")
    # One-line caption describing the table contents.
    table_summary: str = Field(..., description="short, concise summary/caption of the table")


# Prompt asking the LLM for a table name and summary.
# Placeholders: {exclude_table_name_list} (names already taken) and
# {table_str} (CSV text of the table).
prompt_str = """\
Give me a summary of the table with the following JSON format.

- The table name must be unique to the table and describe it while being concise. 
- Do NOT output a generic table name (e.g. table, my_table).

Do NOT make the table name one of the following: {exclude_table_name_list}

Table:
{table_str}

Summary: """
# Wrap the prompt as a single user message.
prompt_tmpl = ChatPromptTemplate(
    message_templates=[ChatMessage.from_str(prompt_str, role="user")]
)

# LLM used for the table summaries below.
llm = OpenAI(model="gpt-4o-mini")

def _get_tableinfo_with_index(idx: int) -> "TableInfo | None":
    """Load the cached TableInfo whose file name starts with ``"{idx}_"``.

    Looks inside the module-level ``tableinfo_dir``.

    Returns:
        The parsed TableInfo, or ``None`` when no cache file exists for ``idx``.

    Raises:
        ValueError: if more than one file matches the index.
    """
    matches = sorted(Path(tableinfo_dir).glob(f"{idx}_*"))
    if not matches:
        return None
    if len(matches) > 1:
        # Report the materialised list: the glob generator is already
        # exhausted here, so list()-ing it again would always show [].
        raise ValueError(
            f"More than one file matching index: {matches}"
        )
    with open(matches[0], "r", encoding="utf-8") as file:
        return TableInfo.model_validate(json.load(file))


# Build exactly one TableInfo per DataFrame, reusing cached JSON files when
# present and asking the LLM otherwise.
table_names = set()
table_infos = []
for idx, df in enumerate(dfs):
    table_info = _get_tableinfo_with_index(idx)
    if table_info is None:
        # No cache yet: retry until the LLM proposes a name not used so far.
        df_str = df.head(10).to_csv()
        while True:
            table_info = llm.structured_predict(
                TableInfo,
                prompt_tmpl,
                table_str=df_str,
                exclude_table_name_list=str(list(table_names)),
            )
            table_name = table_info.table_name
            print(f"Processed table: {table_name}")
            if table_name not in table_names:
                break
            print(f"Table name {table_name} already exists, trying again.")

        out_file = f"{tableinfo_dir}/{idx}_{table_name}.json"
        # Context manager so the cache file is flushed and closed right away.
        with open(out_file, "w") as out:
            json.dump(table_info.model_dump(), out)

    # Register cached names as well, so later LLM calls are told to avoid
    # them, and append once per table (the original appended cached entries
    # twice: once in the `if` branch and once after it).
    table_names.add(table_info.table_name)
    table_infos.append(table_info)

将上面的表格写入SQLite

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
)
import re


# Function to create a sanitized column name
def sanitize_column_name(col_name):
    """Replace every run of non-word characters in ``col_name`` with one underscore."""
    non_word_run = re.compile(r"\W+")
    return non_word_run.sub("_", col_name)


# Function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    """Create ``table_name`` on ``engine`` and insert every row of ``df``.

    Column names are passed through ``sanitize_column_name``. Columns with
    dtype ``object`` are declared as String, all others as Integer.
    """
    renamed = df.rename(
        columns={name: sanitize_column_name(name) for name in df.columns}
    )

    # Derive one SQLAlchemy column per DataFrame column.
    columns = []
    for col_name, col_dtype in zip(renamed.columns, renamed.dtypes):
        sql_type = String if col_dtype == "object" else Integer
        columns.append(Column(col_name, sql_type))

    # Registering the table on the shared metadata lets create_all emit it.
    table = Table(table_name, metadata_obj, *columns)
    metadata_obj.create_all(engine)

    # Insert row by row, committing once at the end.
    with engine.connect() as conn:
        for _, row in renamed.iterrows():
            conn.execute(table.insert().values(**row.to_dict()))
        conn.commit()


# engine = create_engine("sqlite:///:memory:")
# File-backed SQLite database; the commented line above is the in-memory variant.
engine = create_engine("sqlite:///wiki_table_questions.db")
metadata_obj = MetaData()
# Create one SQL table per DataFrame, named after its cached TableInfo.
for idx, df in enumerate(dfs):
    # NOTE(review): returns None when no cache file exists for idx, which
    # would fail on the attribute access below - run the summary step first.
    tableinfo = _get_tableinfo_with_index(idx)
    print(f"Creating table: {tableinfo.table_name}")
    create_table_from_dataframe(df, tableinfo.table_name, engine, metadata_obj)

图片.png

基础工具构建

构建向量索引

from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import SQLDatabase, VectorStoreIndex

# Wrap the SQLAlchemy engine for use by LlamaIndex.
sql_database = SQLDatabase(engine)

table_node_mapping = SQLTableNodeMapping(sql_database)
# Pair every table name with its summary.
table_schema_objs = [
    SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
    for t in table_infos
]  # add a SQLTableSchema for each table

# Index the table schemas with a vector store index.
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
# Retriever configured with similarity_top_k=3.
obj_retriever = obj_index.as_retriever(similarity_top_k=3)

创建sql查询器

from llama_index.core.retrievers import SQLRetriever
from typing import List

# Retriever bound to the SQL database created above.
sql_retriever = SQLRetriever(sql_database)


def get_table_context_str(table_schema_objs: List[SQLTableSchema]):
    """Join the table info of every schema object, separated by blank lines.

    Each entry is what the module-level ``sql_database`` returns from
    ``get_single_table_info`` for the table, followed by the schema object's
    description when it has one.
    """
    parts = []
    for schema in table_schema_objs:
        text = sql_database.get_single_table_info(schema.table_name)
        if schema.context_str:
            text = text + " The table description is: " + schema.context_str
        parts.append(text)
    return "\n\n".join(parts)

创建Text2SQL提示词(系统默认模板)和输出结果解析器(从生成的文本中抽取SQL)

from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core import PromptTemplate
from llama_index.core.llms import ChatResponse

def parse_response_to_sql(chat_response: "ChatResponse") -> str:
    """Extract the bare SQL statement from an LLM chat response.

    Keeps the text after the first ``SQLQuery:`` marker (if present), cuts
    everything from ``SQLResult:`` onwards, and removes a surrounding
    Markdown code fence together with its language tag, so that
    "```sql\\nSELECT 1\\n```" yields ``SELECT 1`` rather than
    ``sql\\nSELECT 1``.

    Args:
        chat_response: object exposing ``message.content`` (the raw LLM text).

    Returns:
        The SQL text; an empty string when the response has no content.
    """
    # Language tags that may follow an opening fence; kept as an explicit
    # list so a first line holding real SQL is never dropped.
    fence_tags = ("sql", "sqlite", "sqlite3", "mysql", "postgres", "postgresql")

    response = chat_response.message.content or ""
    _, marker, after_marker = response.partition("SQLQuery:")
    if marker:
        response = after_marker
    response = response.partition("SQLResult:")[0]

    sql = response.strip()
    fenced = sql.startswith("```")
    sql = sql.strip("`").strip()
    if fenced:
        tag, newline, body = sql.partition("\n")
        if newline and tag.strip().lower() in fence_tags:
            sql = body.strip()
    return sql


# Pre-fill the dialect slot of the default text-to-SQL prompt with the
# engine's dialect name, then print the template for inspection.
text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
    dialect=engine.dialect.name
)
print(text2sql_prompt.template)

图片.png

创建自然语言回复模板

# Prompt for the final natural-language answer.
# Placeholders: {query_str}, {sql_query}, {context_str}.
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)
response_synthesis_prompt = PromptTemplate(
    response_synthesis_prompt_str,
)

定义工作流

from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    step,
    Context,
    Event,
)

# 事件:找到了数据库中相关的表
class TableRetrieveEvent(Event):
    """Emitted after table retrieval: schema context plus the user query."""

    table_context_str: str
    query: str


# Event: the question has been turned into SQL
class TextToSQLEvent(Event):
    """Emitted after SQL generation: the SQL text plus the user query."""

    sql: str
    query: str


class TextToSQLWorkflow1(Workflow):
    """Text-to-SQL Workflow that does query-time table retrieval.

    Steps: retrieve_tables -> generate_sql -> generate_response.
    """

    def __init__(
        self,
        obj_retriever,
        text2sql_prompt,
        sql_retriever,
        response_synthesis_prompt,
        llm,
        *args,
        **kwargs
    ) -> None:
        """Store the collaborators; remaining arguments go to ``Workflow``."""
        super().__init__(*args, **kwargs)
        self.obj_retriever = obj_retriever
        self.text2sql_prompt = text2sql_prompt
        self.sql_retriever = sql_retriever
        self.response_synthesis_prompt = response_synthesis_prompt
        self.llm = llm

    @step
    def retrieve_tables(
        self, ctx: Context, ev: StartEvent
    ) -> TableRetrieveEvent:
        """Retrieve the table schemas for ``ev.query`` and build their context text."""
        table_schema_objs = self.obj_retriever.retrieve(ev.query)
        table_context_str = get_table_context_str(table_schema_objs)
        print("====\n"+table_context_str+"\n====")
        return TableRetrieveEvent(
            table_context_str=table_context_str, query=ev.query
        )

    @step
    def generate_sql(
        self, ctx: Context, ev: TableRetrieveEvent
    ) -> TextToSQLEvent:
        """Ask the LLM for a SQL statement given the query and the table context."""
        fmt_messages = self.text2sql_prompt.format_messages(
            query_str=ev.query, schema=ev.table_context_str
        )
        chat_response = self.llm.chat(fmt_messages)
        sql = parse_response_to_sql(chat_response)
        print("====\n"+sql+"\n====")
        return TextToSQLEvent(sql=sql, query=ev.query)

    @step
    def generate_response(self, ctx: Context, ev: TextToSQLEvent) -> StopEvent:
        """Run the SQL and have the LLM phrase the rows as an answer."""
        retrieved_rows = self.sql_retriever.retrieve(ev.sql)
        print("====\n"+str(retrieved_rows)+"\n====")
        fmt_messages = self.response_synthesis_prompt.format_messages(
            sql_query=ev.sql,
            context_str=str(retrieved_rows),
            query_str=ev.query,
        )
        # Use the LLM injected through __init__ (as generate_sql does) rather
        # than a module-level `llm`, so the constructor argument is honoured.
        chat_response = self.llm.chat(fmt_messages)
        return StopEvent(result=chat_response)

运行

# NOTE(review): create_table(), create_retriever() and the sql_database
# constructor argument used below correspond to the definitions in the
# complete listing at the end of the article.
async def task(lock):
    # Build the database, the table retriever and the SQL retriever.
    engine, table_infos = create_table()
    obj_retriever, sql_database = create_retriever(engine, table_infos)
    sql_retriever = SQLRetriever(sql_database)

    # Default text-to-SQL prompt with the dialect pre-filled from the engine.
    text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
        dialect=engine.dialect.name
    )
    response_synthesis_prompt_str = (
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: "
    )
    response_synthesis_prompt = PromptTemplate(
        response_synthesis_prompt_str,
    )

    workflow = TextToSQLWorkflow1(
        obj_retriever,
        text2sql_prompt,
        sql_retriever,
        sql_database,
        response_synthesis_prompt,
        llm,
        verbose=True,
    )
    # Run one query while holding the lock and print the answer.
    async with lock:

        response = await workflow.run(
            query="What was the year that The Notorious B.I.G was signed to Bad Boy?"
        )
        print(str(response))


async def main():
    # Create the lock and await the single task.
    lock = asyncio.Lock()
    tasks = [task(lock)]
    await asyncio.gather(*tasks)

if __name__ == "__main__":
    asyncio.run(main())
  

图片.png

可视化工作流

from llama_index.utils.workflow import draw_all_possible_flows

# Pass the workflow class: the instance named `workflow` is created inside
# task() and is not available at module level here.
draw_all_possible_flows(
    TextToSQLWorkflow1, filename="text_to_sql_table_retrieval.html"
)

图片.png

完整代码

import asyncio
import json
import os
import re
import traceback
from pathlib import Path
from typing import List

import nest_asyncio
import openai
import pandas as pd
from llama_index.core import PromptTemplate
from llama_index.core import SQLDatabase, VectorStoreIndex
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.core.llms import ChatMessage
from llama_index.core.llms import ChatResponse
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core.prompts import ChatPromptTemplate
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    step,
    Context,
    Event,
)
from llama_index.llms.openai import OpenAI
# put data into sqlite db
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
)

# Placeholders: replace with your endpoint and key.
# NOTE(review): prefer loading the key from the environment over committing it.
openai.base_url = 'your api url'
openai.api_key = 'your api key'

# Prompt asking the LLM for a table name and summary.
# Placeholders: {exclude_table_name_list} and {table_str}.
prompt_str = """\
Give me a summary of the table with the following JSON format.

- The table name must be unique to the table and describe it while being concise. 
- Do NOT output a generic table name (e.g. table, my_table).

Do NOT make the table name one of the following: {exclude_table_name_list}

Table:
{table_str}

Summary: """

# Wrap the prompt as a single user message.
prompt_tmpl = ChatPromptTemplate(
    message_templates=[ChatMessage.from_str(prompt_str, role="user")]
)

# Shared LLM instance, created with temperature=0.
llm = OpenAI(model="gpt-4o-mini",temperature=0)
def process_file():
    """Read every CSV under ./WikiTableQuestions/csv/200-csv into a DataFrame.

    Files are visited in sorted path order. A file that cannot be parsed is
    reported with its traceback and skipped.

    Returns:
        The list of successfully loaded DataFrames.
    """
    source_dir = Path('./WikiTableQuestions/csv/200-csv')
    frames = []
    for path in sorted(source_dir.glob('*.csv')):
        print(f'正在处理文件:{path}')
        try:
            frames.append(pd.read_csv(path))
        except Exception:
            print(f'处理文件:{path} 失败:{traceback.format_exc()}')
    return frames

class TableInfo(BaseModel):
    """Information regarding a structured table."""

    # Name to use for the table; used later as the SQL table name.
    table_name: str = Field(..., description="table name (must be underscores and NO spaces)")
    # One-line caption describing the table contents.
    table_summary: str = Field(..., description="short, concise summary/caption of the table")


def _get_tableinfo_with_index(idx: int, tableinfo_dir) -> "TableInfo | None":
    """Load the cached TableInfo whose file name starts with ``"{idx}_"``.

    Args:
        idx: position of the table in the list of loaded DataFrames.
        tableinfo_dir: directory holding the cached JSON files.

    Returns:
        The parsed TableInfo, or ``None`` when no cache file exists for ``idx``.

    Raises:
        ValueError: if more than one file matches the index.
    """
    matches = sorted(Path(tableinfo_dir).glob(f"{idx}_*"))
    if not matches:
        return None
    if len(matches) > 1:
        # Report the materialised list: the glob generator is already
        # exhausted here, so list()-ing it again would always show [].
        raise ValueError(
            f"More than one file matching index: {matches}"
        )
    with open(matches[0], "r", encoding="utf-8") as file:
        return TableInfo.model_validate(json.load(file))


# Function to create a sanitized column name
def sanitize_column_name(col_name):
    """Replace every run of non-word characters in ``col_name`` with one underscore."""
    non_word_run = re.compile(r"\W+")
    return non_word_run.sub("_", col_name)


# Function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    """Create ``table_name`` on ``engine`` and insert every row of ``df``.

    Column names are passed through ``sanitize_column_name``. Columns with
    dtype ``object`` are declared as String, all others as Integer.
    """
    renamed = df.rename(
        columns={name: sanitize_column_name(name) for name in df.columns}
    )

    # Derive one SQLAlchemy column per DataFrame column.
    columns = []
    for col_name, col_dtype in zip(renamed.columns, renamed.dtypes):
        sql_type = String if col_dtype == "object" else Integer
        columns.append(Column(col_name, sql_type))

    # Registering the table on the shared metadata lets create_all emit it.
    table = Table(table_name, metadata_obj, *columns)
    metadata_obj.create_all(engine)

    # Insert row by row, committing once at the end.
    with engine.connect() as conn:
        for _, row in renamed.iterrows():
            conn.execute(table.insert().values(**row.to_dict()))
        conn.commit()

def write_table_info_dir(dfs, tableinfo_dir):
    """Return one TableInfo per DataFrame, creating missing cache files.

    For each DataFrame the cached JSON in ``tableinfo_dir`` is reused when
    present. Otherwise the LLM is asked, repeatedly if necessary, for a
    summary whose table name has not been used yet, and the result is
    written to ``{tableinfo_dir}/{idx}_{table_name}.json``.

    Args:
        dfs: DataFrames in the order that defines their index.
        tableinfo_dir: directory holding the cached JSON files.

    Returns:
        A list with exactly one TableInfo per entry of ``dfs``, in order.
    """
    table_names = set()
    table_infos = []
    for idx, df in enumerate(dfs):
        table_info = _get_tableinfo_with_index(idx, tableinfo_dir)
        if table_info is None:
            # No cache yet: retry until the LLM proposes an unused name.
            df_str = df.head(10).to_csv()
            while True:
                table_info = llm.structured_predict(
                    TableInfo,
                    prompt_tmpl,
                    table_str=df_str,
                    exclude_table_name_list=str(list(table_names)),
                )
                table_name = table_info.table_name
                print(f"Processed table: {table_name}")
                if table_name not in table_names:
                    break
                print(f"Table name {table_name} already exists, trying again.")

            out_file = f"{tableinfo_dir}/{idx}_{table_name}.json"
            # Context manager so the cache file is flushed and closed at once.
            with open(out_file, "w") as out:
                json.dump(table_info.model_dump(), out)

        # Register cached names too, so later LLM calls are told to avoid
        # them, and append once per table (the original appended cached
        # entries twice, so the result was longer than dfs).
        table_names.add(table_info.table_name)
        table_infos.append(table_info)
    return table_infos

# engine = create_engine("sqlite:///:memory:")
def create_table():
    """Rebuild the SQLite database from the CSV files.

    Deletes an existing ``wiki_table_questions.db``, loads the CSVs, makes
    sure every table has a cached TableInfo, then creates and fills one SQL
    table per DataFrame.

    Returns:
        ``(engine, table_infos)``: the SQLAlchemy engine and the value
        returned by ``write_table_info_dir``.
    """
    tableinfo_dir = "WikiTableQuestions_TableInfo"
    db_file = "wiki_table_questions.db"

    os.makedirs(tableinfo_dir, exist_ok=True)
    if os.path.exists(db_file):
        os.remove(db_file)

    engine = create_engine("sqlite:///wiki_table_questions.db", echo=True)
    metadata_obj = MetaData()

    dfs = process_file()
    tables_infos = write_table_info_dir(dfs, tableinfo_dir)

    for position, frame in enumerate(dfs):
        # Table names are read back from the cache files on disk.
        cached = _get_tableinfo_with_index(position, tableinfo_dir)
        print(f"Creating table: {cached.table_name}")
        create_table_from_dataframe(frame, cached.table_name, engine, metadata_obj)
    return engine, tables_infos



def create_retriever(engine,table_infos):
    """Build the table retriever and the SQLDatabase wrapper for ``engine``.

    Returns:
        ``(obj_retriever, sql_database)``; the retriever is created with
        ``similarity_top_k=3``.
    """
    sql_database = SQLDatabase(engine)

    # One schema object per table: its name plus its summary.
    schemas = []
    for info in table_infos:
        schemas.append(
            SQLTableSchema(table_name=info.table_name, context_str=info.table_summary)
        )

    index = ObjectIndex.from_objects(
        schemas,
        SQLTableNodeMapping(sql_database),
        VectorStoreIndex,
    )
    return index.as_retriever(similarity_top_k=3), sql_database



def get_table_context_str(table_schema_objs: List[SQLTableSchema],sql_database):
    """Join the table info of every schema object, separated by blank lines.

    Each entry is what ``sql_database.get_single_table_info`` returns for the
    table, followed by the schema object's description when it has one.
    """
    parts = []
    for schema in table_schema_objs:
        text = sql_database.get_single_table_info(schema.table_name)
        if schema.context_str:
            text = text + " The table description is: " + schema.context_str
        parts.append(text)
    return "\n\n".join(parts)

def parse_response_to_sql(chat_response: "ChatResponse") -> str:
    """Extract the bare SQL statement from an LLM chat response.

    Keeps the text after the first ``SQLQuery:`` marker (if present), cuts
    everything from ``SQLResult:`` onwards, and removes a surrounding
    Markdown code fence together with its language tag, so that
    "```sql\\nSELECT 1\\n```" yields ``SELECT 1`` rather than
    ``sql\\nSELECT 1``.

    Args:
        chat_response: object exposing ``message.content`` (the raw LLM text).

    Returns:
        The SQL text; an empty string when the response has no content.
    """
    # Language tags that may follow an opening fence; kept as an explicit
    # list so a first line holding real SQL is never dropped.
    fence_tags = ("sql", "sqlite", "sqlite3", "mysql", "postgres", "postgresql")

    response = chat_response.message.content or ""
    _, marker, after_marker = response.partition("SQLQuery:")
    if marker:
        response = after_marker
    response = response.partition("SQLResult:")[0]

    sql = response.strip()
    fenced = sql.startswith("```")
    sql = sql.strip("`").strip()
    if fenced:
        tag, newline, body = sql.partition("\n")
        if newline and tag.strip().lower() in fence_tags:
            sql = body.strip()
    return sql




# 事件:找到了数据库中相关的表
class TableRetrieveEvent(Event):
    """Emitted after table retrieval: schema context plus the user query."""

    table_context_str: str
    query: str


# Event: the question has been turned into SQL
class TextToSQLEvent(Event):
    """Emitted after SQL generation: the SQL text plus the user query."""

    sql: str
    query: str


class TextToSQLWorkflow1(Workflow):
    """Text-to-SQL Workflow that does query-time table retrieval.

    Steps: retrieve_tables -> generate_sql -> generate_response.
    """

    def __init__(
        self,
        obj_retriever,
        text2sql_prompt,
        sql_retriever,
        sql_database,
        response_synthesis_prompt,
        llm,
        *args,
        **kwargs
    ) -> None:
        """Store the collaborators; remaining arguments go to ``Workflow``."""
        super().__init__(*args, **kwargs)
        self.obj_retriever = obj_retriever
        self.text2sql_prompt = text2sql_prompt
        self.sql_retriever = sql_retriever
        self.sql_database = sql_database
        self.response_synthesis_prompt = response_synthesis_prompt
        self.llm = llm

    # NOTE(review): the steps are coroutines but call the synchronous
    # retrieve()/chat() APIs; consider the async counterparts if steps
    # should be able to overlap.

    @step
    async def retrieve_tables(
        self, ctx: Context, ev: StartEvent
    ) -> TableRetrieveEvent:
        """Retrieve the table schemas for ``ev.query`` and build their context text."""
        table_schema_objs = self.obj_retriever.retrieve(ev.query)
        table_context_str = get_table_context_str(table_schema_objs,self.sql_database)
        print("====\n"+table_context_str+"\n====")
        return TableRetrieveEvent(
            table_context_str=table_context_str, query=ev.query
        )

    @step
    async def generate_sql(
        self, ctx: Context, ev: TableRetrieveEvent
    ) -> TextToSQLEvent:
        """Ask the LLM for a SQL statement given the query and the table context."""
        fmt_messages = self.text2sql_prompt.format_messages(
            query_str=ev.query, schema=ev.table_context_str
        )
        chat_response = self.llm.chat(fmt_messages)
        sql = parse_response_to_sql(chat_response)
        print("====\n"+sql+"\n====")
        return TextToSQLEvent(sql=sql, query=ev.query)

    @step
    async def generate_response(self, ctx: Context, ev: TextToSQLEvent) -> StopEvent:
        """Run the SQL and have the LLM phrase the rows as an answer."""
        retrieved_rows = self.sql_retriever.retrieve(ev.sql)
        print("====\n"+str(retrieved_rows)+"\n====")
        fmt_messages = self.response_synthesis_prompt.format_messages(
            sql_query=ev.sql,
            context_str=str(retrieved_rows),
            query_str=ev.query,
        )
        # Use the LLM injected through __init__ (as generate_sql does) rather
        # than the module-level `llm`, so the constructor argument is honoured.
        chat_response = self.llm.chat(fmt_messages)
        return StopEvent(result=chat_response)

async def task(lock):
    """Build all components, then run one query under ``lock`` and print the answer."""
    engine, table_infos = create_table()
    obj_retriever, sql_database = create_retriever(engine, table_infos)

    # Final-answer prompt; placeholders: query_str, sql_query, context_str.
    synthesis_template = PromptTemplate(
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: ",
    )

    wf = TextToSQLWorkflow1(
        obj_retriever,
        DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(dialect=engine.dialect.name),
        SQLRetriever(sql_database),
        sql_database,
        synthesis_template,
        llm,
        verbose=True,
    )

    question = "What was the year that The Notorious B.I.G was signed to Bad Boy?"
    async with lock:
        answer = await wf.run(query=question)
        print(str(answer))


async def main():
    """Create the lock and await the single Text-to-SQL task."""
    shared_lock = asyncio.Lock()
    await asyncio.gather(task(shared_lock))

if __name__ == "__main__":
    asyncio.run(main())
    from llama_index.utils.workflow import draw_all_possible_flows

    # Draw from the workflow class: `test_to_sql_wf` was never defined
    # (NameError), and the instance built in task() is local to it.
    draw_all_possible_flows(
        TextToSQLWorkflow1, filename="text_to_sql_table_retrieval.html"
    )