WorkFlow的基本使用
WorkFlow,也就是工作流,作为程序员对这个东西并不陌生,因为写代码的过程中有很大一部分就是处理各种流程,对各种流程进行抽象。
假设现在有一个需求,整体流程是这样的
任务开始执行事件A,然后并发执行事件B、C,等事件B、C全部执行完毕后,执行事件D,然后流程结束,并不复杂对不对?
但是在你绞尽脑汁刚刚写好代码,处理好并发问题,测试完毕,产品修改了需求,事件B、C,不并发执行了,改成顺序执行。。。
在LlamaIndex中这种修改非常方便,看下面的代码
这是初版需求的工作流
from llama_index.core.workflow import Workflow, Event, StartEvent, StopEvent, Context, step, draw_all_possible_flows
class EventA(Event):
    pass


class EventB(Event):
    pass


class EventC(Event):
    pass


# NOTE(review): EventD is declared but never emitted by any step; the "D"
# stage of the flow is implemented by the step_d method below.
class EventD(Event):
    pass


class WorkFlowTest(Workflow):
    """Workflow: A -> (B and C fan out from A) -> D -> stop."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @step
    def step_a(self, ctx: Context, ev: StartEvent) -> EventA:
        """Entry step: consume the StartEvent and kick off EventA."""
        return EventA()

    @step
    def step_b(self, ctx: Context, ev: EventA) -> EventB:
        """Triggered by EventA; runs alongside step_c."""
        return EventB()

    @step
    def step_c(self, ctx: Context, ev: EventA) -> EventC:
        """Triggered by EventA; runs alongside step_b."""
        return EventC()

    @step
    def step_d(self, ctx: Context, ev: EventB | EventC) -> StopEvent | None:
        """Finish only after BOTH EventB and EventC have arrived.

        BUG FIX: the original fired on the first of EventB/EventC, which
        contradicts the stated requirement that D waits for B AND C to
        complete. ctx.collect_events buffers events and returns None until
        all requested event types have been received.
        """
        if ctx.collect_events(ev, [EventB, EventC]) is None:
            return None
        return StopEvent()
这是修改之后的,仅仅改动了一下参数类型和返回值
from llama_index.core.workflow import Workflow, Event, StartEvent, StopEvent, Context, step, draw_all_possible_flows
class EventA(Event):
    pass


class EventB(Event):
    pass


class EventC(Event):
    pass


# NOTE(review): EventD is declared but never emitted; step_d below acts as
# the "D" stage and ends the flow directly with a StopEvent.
class EventD(Event):
    pass


class WorkFlowTest(Workflow):
    # Sequential variant: A -> B -> C -> D -> stop. Only the event type
    # annotations changed relative to the first version; in LlamaIndex the
    # step annotations alone define the event routing graph.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @step
    def step_a(self,ctx:Context,ev:StartEvent) -> EventA:
        # Entry point: StartEvent -> EventA.
        return EventA()

    @step
    def step_b(self,ctx:Context,ev:EventA) -> EventB:
        # Runs strictly after step_a.
        return EventB()

    @step
    def step_c(self, ctx: Context, ev: EventB) -> EventC:
        # Now consumes EventB instead of EventA, making B -> C sequential.
        return EventC()

    @step
    def step_d(self, ctx: Context, ev: EventC) -> StopEvent:
        # Final step: EventC -> StopEvent ends the run.
        return StopEvent()
是不是非常方便!
在LlamaIndex中,工作流是由step组成的,从StartEvent开始,每个step会处理特定的Event,处理完毕后触发下一个事件交给下一个step处理,直到产生StopEvent
当然在LlamaIndex的工作流中不止能处理这种简单的顺序执行的场景,还支持循环,并发等多种情况
循环工作流
循环其实也非常简单,只需要定义一个Loop事件就可以了,根据官方文档,事件可以拥有任何自定义的名称
import asyncio
from llama_index.core.workflow import Workflow, Event, StartEvent, StopEvent, Context, step, draw_all_possible_flows
class EventA(Event):
    # msg carries a human-readable label of the triggering event.
    msg:str


class EventB(Event):
    msg:str


class EventC(Event):
    msg:str


class LoopEvent(Event):
    # Emitted by step_b to re-trigger itself, forming the loop.
    msg:str


class WorkFlowTest(Workflow):
    # Loop demo: step_b re-emits LoopEvent until it has run `count` times,
    # then the flow continues B -> C -> D -> stop.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.count = 5      # total number of times step_b should run
        self.loop_num = 1   # 1-based counter of step_b executions
        # NOTE(review): the counter lives on the workflow instance, so a
        # reused instance would not loop again on a second run() — confirm
        # whether that is intended.

    @step
    async def step_a(self,ctx:Context,ev:StartEvent) -> EventA:
        # StartEvent carries the `msg` keyword passed to workflow.run(...).
        print(ev.msg)
        print('我是步骤a')
        return EventA(msg='事件A')

    @step
    async def step_b(self,ctx:Context,ev:EventA|LoopEvent) -> EventB|LoopEvent:
        # Accepts both the initial EventA and its own LoopEvent.
        print(f"我是步骤b,由{ev.msg}触发,循环次数{self.count} ,当前第{self.loop_num}次")
        self.loop_num +=1
        if self.loop_num <= self.count:
            # Not done yet: emit LoopEvent to trigger this step again.
            return LoopEvent(msg="循环事件")
        return EventB(msg="事件b")

    @step
    async def step_c(self, ctx: Context, ev: EventB) -> EventC:
        print(f"我是步骤c,由{ev.msg}触发")
        return EventC(msg="事件c")

    @step
    async def step_d(self, ctx: Context, ev: EventC) -> StopEvent:
        # StopEvent.result becomes the return value of workflow.run().
        print(f"我是步骤d,由{ev.msg}触发")
        return StopEvent(result="流程结束")
workflow = WorkFlowTest()


async def task(lock):
    """Run the workflow once, serialized by *lock*, and print its result."""
    async with lock:
        response = await workflow.run(
            msg="流程开始"
        )
    print(str(response))


async def main():
    # Only one task is spawned here; the lock matters once more are added.
    lock = asyncio.Lock()
    await asyncio.gather(task(lock))


if __name__ == "__main__":
    asyncio.run(main())
Context
在提并发工作流之前,先说明一下步骤和步骤之间是怎么传递数据的,在上面的代码中,可以注意到有一个Context参数,这个参数就是用来传递数据用的,这个有点类似于Go语言里的Context,可以用于上下文之间的数据传递,而不是依靠Event,将数据在所有步骤间到处传递
class SetupEvent(Event):
    query: str


class StepTwoEvent(Event):
    query: str


class StatefulFlow(Workflow):
    # Demonstrates sharing state between steps through the Context store
    # instead of threading data through every Event.
    @step
    async def start(
        self, ctx: Context, ev: StartEvent
    ) -> SetupEvent | StepTwoEvent:
        # First run: nothing cached in the context yet, so branch to setup.
        db = await ctx.get("some_database", default=None)
        if db is None:
            print("Need to load data")
            return SetupEvent(query=ev.query)

        # do something with the query
        return StepTwoEvent(query=ev.query)

    @step
    async def setup(self, ctx: Context, ev: SetupEvent) -> StartEvent:
        # load data, then re-enter start() by emitting a new StartEvent.
        await ctx.set("some_database", [1, 2, 3])
        return StartEvent(query=ev.query)
    # NOTE(review): no step consumes StepTwoEvent in this fragment — the
    # example presumably continues elsewhere; verify against the docs.
并发工作流
并发工作流,也需要这个Context,具体可以看代码
import asyncio
from llama_index.core.workflow import Workflow, Event, StartEvent, StopEvent, Context, step, draw_all_possible_flows
class EventA(Event):
    msg:str


class EventB(Event):
    msg:str


class EventC(Event):
    msg:str


# NOTE(review): LoopEvent is unused in this example; it is a leftover from
# the loop demo and kept only so the snippet stays drop-in compatible.
class LoopEvent(Event):
    msg:str


class WorkFlowTest(Workflow):
    """Fan-out demo: step_b emits EventB and EventC concurrently."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.count = 5
        self.loop_num = 1

    @step
    async def step_a(self, ctx: Context, ev: StartEvent) -> EventA:
        print(ev.msg)
        print('我是步骤a')
        return EventA(msg='事件A')

    @step
    async def step_b(self, ctx: Context, ev: EventA) -> EventB|EventC:
        # Emit both events explicitly instead of returning one; this is how
        # a step triggers several downstream steps in parallel.
        ctx.send_event(EventB(msg='事件b'))
        # BUG FIX: EventC used to be created with msg='事件b', which made
        # step_d print that it was triggered by "事件b".
        ctx.send_event(EventC(msg='事件c'))
        return None

    @step
    async def step_c(self, ctx: Context, ev: EventB) -> StopEvent:
        print(f"我是步骤c,由{ev.msg}触发,我和d同时执行")
        return StopEvent(result="流程结束")

    @step
    async def step_d(self, ctx: Context, ev: EventC) -> StopEvent:
        # NOTE(review): whichever of step_c/step_d finishes first produces
        # the StopEvent that ends the run.
        print(f"我是步骤d,由{ev.msg}触发,我和c同时执行")
        return StopEvent(result="流程结束")


workflow = WorkFlowTest()


async def task(lock):
    async with lock:
        response = await workflow.run(
            msg="流程开始"
        )
    print(str(response))


async def main():
    lock = asyncio.Lock()
    tasks = [task(lock)]
    res = await asyncio.gather(*tasks)


if __name__ == "__main__":
    asyncio.run(main())
关于工作流还有更多的内容,具体可以参考官方文档
案例:使用自然语言进行数据库查询
需求说明
- 用户输入自然语言查询
- 系统先去检索跟查询相关的表
- 根据表的 Schema 让大模型生成 SQL
- 用生成的 SQL 查询数据库
- 根据查询结果,调用大模型生成自然语言回复
数据准备
# Download WikiTableQuestions
# WikiTableQuestions is a dataset designed for table question answering,
# containing 2,108 HTML tables extracted from Wikipedia
wget "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip" -O wiki_data.zip
unzip wiki_data.zip
数据入库
遍历目录加载表格
import pandas as pd
from pathlib import Path
# Folder that holds the extracted WikiTableQuestions CSV tables.
data_dir = Path("./WikiTableQuestions/csv/200-csv")
csv_files = sorted(data_dir.glob("*.csv"))

# Parse each CSV into a DataFrame, skipping files pandas cannot read.
dfs = []
for csv_file in csv_files:
    print(f"processing file: {csv_file}")
    try:
        dfs.append(pd.read_csv(csv_file))
    except Exception as e:
        print(f"Error parsing {csv_file}: {str(e)}")
为每个表生成一段文字描述,保存到 WikiTableQuestions_TableInfo 目录
import os
import json
from llama_index.core.prompts import ChatPromptTemplate
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.llms.openai import OpenAI
from llama_index.core.llms import ChatMessage
# Directory where one JSON description per table will be cached.
tableinfo_dir = "WikiTableQuestions_TableInfo"
if not os.path.exists(tableinfo_dir):
    os.mkdir(tableinfo_dir)


class TableInfo(BaseModel):
    """Information regarding a structured table."""

    # Unique, underscore-separated identifier used as the SQL table name.
    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    # One-line natural-language summary used for table retrieval.
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )
# Prompt asking the LLM to describe one table as TableInfo JSON;
# {exclude_table_name_list} is used to avoid duplicate table names.
prompt_str = """\
Give me a summary of the table with the following JSON format.
- The table name must be unique to the table and describe it while being concise.
- Do NOT output a generic table name (e.g. table, my_table).
Do NOT make the table name one of the following: {exclude_table_name_list}
Table:
{table_str}
Summary: """

prompt_tmpl = ChatPromptTemplate(
    message_templates=[ChatMessage.from_str(prompt_str, role="user")]
)

llm = OpenAI(model="gpt-4o-mini")
def _get_tableinfo_with_index(idx: int) -> "TableInfo | None":
    """Load the cached TableInfo JSON for table *idx*, or None if absent.

    BUG FIX: the return annotation said ``str`` although the function
    returns a TableInfo (or None).

    Raises:
        ValueError: if more than one cached file matches the index.
    """
    results_list = list(Path(tableinfo_dir).glob(f"{idx}_*"))
    if len(results_list) == 0:
        return None
    if len(results_list) == 1:
        with open(results_list[0], 'r') as file:
            data = json.load(file)
        return TableInfo.model_validate(data)
    # BUG FIX: the original formatted the already-exhausted generator here,
    # which always rendered an empty list in the error message.
    raise ValueError(f"More than one file matching index: {results_list}")
# Generate (or load cached) TableInfo for every DataFrame, retrying the
# LLM call until it produces a table name that has not been used before.
table_names = set()
table_infos = []
for idx, df in enumerate(dfs):
    table_info = _get_tableinfo_with_index(idx)
    if table_info:
        table_infos.append(table_info)
        continue

    while True:
        df_str = df.head(10).to_csv()
        table_info = llm.structured_predict(
            TableInfo,
            prompt_tmpl,
            table_str=df_str,
            exclude_table_name_list=str(list(table_names)),
        )
        table_name = table_info.table_name
        print(f"Processed table: {table_name}")
        if table_name not in table_names:
            table_names.add(table_name)
            break
        # try again
        print(f"Table name {table_name} already exists, trying again.")

    out_file = f"{tableinfo_dir}/{idx}_{table_name}.json"
    # BUG FIX: close the output file (the original leaked the handle opened
    # inline) and use pydantic v2's model_dump() to match the
    # model_validate() call used when reading the cache back.
    with open(out_file, "w") as f:
        json.dump(table_info.model_dump(), f)
    table_infos.append(table_info)
将上面的表格写入 SQLite
from sqlalchemy import (
create_engine,
MetaData,
Table,
Column,
String,
Integer,
)
import re
# Function to create a sanitized column name
def sanitize_column_name(col_name):
    # Remove special characters and replace spaces with underscores
    # (every run of non-word characters collapses into a single "_").
    return re.sub(r"\W+", "_", col_name)
# Function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # Dynamically create columns based on DataFrame columns and data types.
    # NOTE(review): only "object" dtypes map to String; every other dtype
    # (including floats) becomes Integer — confirm that is acceptable.
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    # Create a table with the defined columns
    table = Table(table_name, metadata_obj, *columns)

    # Create the table in the database
    metadata_obj.create_all(engine)

    # Insert data from DataFrame into the table, one row at a time.
    with engine.connect() as conn:
        for _, row in df.iterrows():
            insert_stmt = table.insert().values(**row.to_dict())
            conn.execute(insert_stmt)
        conn.commit()
# engine = create_engine("sqlite:///:memory:")
engine = create_engine("sqlite:///wiki_table_questions.db")
metadata_obj = MetaData()

# Create and populate one SQL table per DataFrame, using the cached
# TableInfo name for each.
for idx, df in enumerate(dfs):
    tableinfo = _get_tableinfo_with_index(idx)
    print(f"Creating table: {tableinfo.table_name}")
    create_table_from_dataframe(df, tableinfo.table_name, engine, metadata_obj)
基础工具构建
构建向量索引
from llama_index.core.objects import (
SQLTableNodeMapping,
ObjectIndex,
SQLTableSchema,
)
from llama_index.core import SQLDatabase, VectorStoreIndex
sql_database = SQLDatabase(engine)

# Map each table's name + summary into schema objects and index them so
# the tables relevant to a query can be retrieved semantically.
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
    for t in table_infos
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
# Return the 3 most similar table schemas per query.
obj_retriever = obj_index.as_retriever(similarity_top_k=3)
创建sql查询器
from llama_index.core.retrievers import SQLRetriever
from typing import List
sql_retriever = SQLRetriever(sql_database)


def get_table_context_str(table_schema_objs: List[SQLTableSchema]):
    """Render each retrieved table's schema info (plus its description,
    when available) and join the pieces with blank lines."""
    context_strs = []
    for schema_obj in table_schema_objs:
        info = sql_database.get_single_table_info(schema_obj.table_name)
        if schema_obj.context_str:
            info += " The table description is: " + schema_obj.context_str
        context_strs.append(info)
    return "\n\n".join(context_strs)
创建Text2SQL提示词(系统默认模板)和输出结果解析器(从生成的文本中抽取SQL)
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core import PromptTemplate
from llama_index.core.llms import ChatResponse
def parse_response_to_sql(chat_response: "ChatResponse") -> str:
    """Extract a bare SQL statement from an LLM text-to-SQL completion.

    Handles the "SQLQuery: ... SQLResult: ..." scaffolding produced by the
    default prompt as well as surrounding markdown code fences.
    """
    response = chat_response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        response = response.removeprefix("SQLQuery:")
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    # BUG FIX: the original strip("```") only removes backtick characters,
    # leaving the "sql" language tag behind for ```sql fenced blocks.
    sql = response.strip()
    if sql.startswith("```sql"):
        sql = sql[len("```sql"):]
    elif sql.startswith("```"):
        sql = sql[len("```"):]
    if sql.endswith("```"):
        sql = sql[:-3]
    return sql.strip()
# Bind the SQL dialect into the stock text-to-SQL prompt; {query_str} and
# {schema} stay unfilled until query time.
text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
    dialect=engine.dialect.name
)
print(text2sql_prompt.template)
创建自然语言回复模板
# Prompt used to turn the raw SQL result rows into a natural-language
# answer to the original question.
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)
response_synthesis_prompt = PromptTemplate(
    response_synthesis_prompt_str,
)
定义工作流
from llama_index.core.workflow import (
Workflow,
StartEvent,
StopEvent,
step,
Context,
Event,
)
# Event: the relevant tables in the database were found.
class TableRetrieveEvent(Event):
    """Result of running table retrieval."""

    # Rendered schema/description text for the retrieved tables.
    table_context_str: str
    # The user's original natural-language question.
    query: str


# Event: the text was converted into SQL.
class TextToSQLEvent(Event):
    """Text-to-SQL event."""

    # The generated SQL statement.
    sql: str
    query: str
class TextToSQLWorkflow1(Workflow):
    """Text-to-SQL Workflow that does query-time table retrieval."""

    def __init__(
        self,
        obj_retriever,
        text2sql_prompt,
        sql_retriever,
        sql_database,
        response_synthesis_prompt,
        llm,
        *args,
        **kwargs
    ) -> None:
        """Init params.

        BUG FIX: ``sql_database`` was added (4th positional) so the
        signature matches both the complete-code version of this class and
        the way the "运行" snippet instantiates it; previously the sixth
        positional argument shifted every later parameter by one slot.
        """
        super().__init__(*args, **kwargs)
        self.obj_retriever = obj_retriever
        self.text2sql_prompt = text2sql_prompt
        self.sql_retriever = sql_retriever
        self.sql_database = sql_database
        self.response_synthesis_prompt = response_synthesis_prompt
        self.llm = llm

    @step
    def retrieve_tables(
        self, ctx: Context, ev: StartEvent
    ) -> TableRetrieveEvent:
        """Retrieve tables relevant to the user's question."""
        table_schema_objs = self.obj_retriever.retrieve(ev.query)
        table_context_str = get_table_context_str(table_schema_objs)
        print("====\n" + table_context_str + "\n====")
        return TableRetrieveEvent(
            table_context_str=table_context_str, query=ev.query
        )

    @step
    def generate_sql(
        self, ctx: Context, ev: TableRetrieveEvent
    ) -> TextToSQLEvent:
        """Generate SQL statement from the question plus table schemas."""
        fmt_messages = self.text2sql_prompt.format_messages(
            query_str=ev.query, schema=ev.table_context_str
        )
        chat_response = self.llm.chat(fmt_messages)
        sql = parse_response_to_sql(chat_response)
        print("====\n" + sql + "\n====")
        return TextToSQLEvent(sql=sql, query=ev.query)

    @step
    def generate_response(self, ctx: Context, ev: TextToSQLEvent) -> StopEvent:
        """Run SQL retrieval and generate the natural-language reply."""
        retrieved_rows = self.sql_retriever.retrieve(ev.sql)
        print("====\n" + str(retrieved_rows) + "\n====")
        fmt_messages = self.response_synthesis_prompt.format_messages(
            sql_query=ev.sql,
            context_str=str(retrieved_rows),
            query_str=ev.query,
        )
        # BUG FIX: used the module-level ``llm`` instead of the injected
        # ``self.llm``.
        chat_response = self.llm.chat(fmt_messages)
        return StopEvent(result=chat_response)
运行
async def task(lock):
    # NOTE(review): this snippet is a fragment — create_table,
    # create_retriever and the asyncio import are only defined in the
    # "完整代码" section of this document; confirm which class definition
    # of TextToSQLWorkflow1 (with or without sql_database) is in scope,
    # since sql_database is passed as the fourth positional argument here.
    engine, table_infos = create_table()
    obj_retriever, sql_database = create_retriever(engine, table_infos)
    sql_retriever = SQLRetriever(sql_database)
    # Bind the engine's SQL dialect into the stock text-to-SQL prompt.
    text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
        dialect=engine.dialect.name
    )
    # Prompt for turning SQL result rows into a natural-language answer.
    response_synthesis_prompt_str = (
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: "
    )
    response_synthesis_prompt = PromptTemplate(
        response_synthesis_prompt_str,
    )
    workflow = TextToSQLWorkflow1(
        obj_retriever,
        text2sql_prompt,
        sql_retriever,
        sql_database,
        response_synthesis_prompt,
        llm,
        verbose=True,
    )
    async with lock:
        response = await workflow.run(
            query="What was the year that The Notorious B.I.G was signed to Bad Boy?"
        )
    print(str(response))


async def main():
    lock = asyncio.Lock()
    tasks = [task(lock)]
    await asyncio.gather(*tasks)


if __name__ == "__main__":
    asyncio.run(main())
可视化工作流
from llama_index.utils.workflow import draw_all_possible_flows

# Render the step/event graph of the workflow to an interactive HTML file.
draw_all_possible_flows(
    workflow, filename="text_to_sql_table_retrieval.html"
)
完整代码
import asyncio
import json
import os
import re
import traceback
from pathlib import Path
from typing import List
import nest_asyncio
import openai
import pandas as pd
from llama_index.core import PromptTemplate
from llama_index.core import SQLDatabase, VectorStoreIndex
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.core.llms import ChatMessage
from llama_index.core.llms import ChatResponse
from llama_index.core.objects import (
SQLTableNodeMapping,
ObjectIndex,
SQLTableSchema,
)
from llama_index.core.prompts import ChatPromptTemplate
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.workflow import (
Workflow,
StartEvent,
StopEvent,
step,
Context,
Event,
)
from llama_index.llms.openai import OpenAI
# put data into sqlite db
from sqlalchemy import (
create_engine,
MetaData,
Table,
Column,
String,
Integer,
)
# OpenAI-compatible endpoint configuration — replace with real values.
openai.base_url = 'your api url'
openai.api_key = 'your api key'

# Prompt asking the LLM to describe one table as TableInfo JSON;
# {exclude_table_name_list} is used to avoid duplicate table names.
prompt_str = """\
Give me a summary of the table with the following JSON format.
- The table name must be unique to the table and describe it while being concise.
- Do NOT output a generic table name (e.g. table, my_table).
Do NOT make the table name one of the following: {exclude_table_name_list}
Table:
{table_str}
Summary: """

prompt_tmpl = ChatPromptTemplate(
    message_templates=[ChatMessage.from_str(prompt_str, role="user")]
)

# temperature=0 keeps the structured predictions deterministic.
llm = OpenAI(model="gpt-4o-mini",temperature=0)
def process_file():
    """Load every CSV under WikiTableQuestions/csv/200-csv into DataFrames.

    Returns:
        list of pandas.DataFrame — one per successfully parsed file;
        unparseable files are logged and skipped.
    """
    data_dir = Path('./WikiTableQuestions/csv/200-csv')
    # BUG FIX (readability): the original reused the name ``csv_file`` for
    # both the sorted list and the loop variable, shadowing the list.
    csv_files = sorted(data_dir.glob('*.csv'))
    dfs = []
    for csv_file in csv_files:
        print(f'正在处理文件:{csv_file}')
        try:
            dfs.append(pd.read_csv(csv_file))
        except Exception:
            print(f'处理文件:{csv_file} 失败:{traceback.format_exc()}')
    return dfs
class TableInfo(BaseModel):
    """Information regarding a structured table."""

    # Unique, underscore-separated identifier used as the SQL table name.
    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    # One-line natural-language summary used for table retrieval.
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )
def _get_tableinfo_with_index(idx: int,tableinfo_dir) -> TableInfo:
results_gen = Path(tableinfo_dir).glob(f"{idx}_*")
results_list = list(results_gen)
if len(results_list) == 0:
return None
elif len(results_list) == 1:
path = results_list[0]
with open(path, 'r') as file:
data = json.load(file)
return TableInfo.model_validate(data)
else:
raise ValueError(
f"More than one file matching index: {list(results_gen)}"
)
# Helper used when writing DataFrames to SQL: SQL identifiers cannot
# contain spaces or punctuation.
def sanitize_column_name(col_name):
    """Return *col_name* with each run of non-word characters replaced by "_"."""
    cleaned = re.sub(r"\W+", "_", col_name)
    return cleaned
# Function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # Dynamically create columns based on DataFrame columns and data types.
    # NOTE(review): only "object" dtypes map to String; every other dtype
    # (including floats) becomes Integer — confirm that is acceptable.
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    # Create a table with the defined columns
    table = Table(table_name, metadata_obj, *columns)

    # Create the table in the database
    metadata_obj.create_all(engine)

    # Insert data from DataFrame into the table, one row at a time.
    with engine.connect() as conn:
        for _, row in df.iterrows():
            insert_stmt = table.insert().values(**row.to_dict())
            conn.execute(insert_stmt)
        conn.commit()
def write_table_info_dir(dfs, tableinfo_dir):
    """Create (or load cached) TableInfo for every DataFrame in *dfs*.

    Cached JSON files in *tableinfo_dir* are reused. Otherwise the LLM is
    asked repeatedly until it proposes a table name not used before, and
    the result is cached to ``{idx}_{table_name}.json``.

    Returns:
        list of TableInfo, one per DataFrame, in the same order.
    """
    table_names = set()
    table_infos = []
    for idx, df in enumerate(dfs):
        table_info = _get_tableinfo_with_index(idx, tableinfo_dir)
        if table_info:
            table_infos.append(table_info)
            continue

        while True:
            df_str = df.head(10).to_csv()
            table_info = llm.structured_predict(
                TableInfo,
                prompt_tmpl,
                table_str=df_str,
                exclude_table_name_list=str(list(table_names)),
            )
            table_name = table_info.table_name
            print(f"Processed table: {table_name}")
            if table_name not in table_names:
                table_names.add(table_name)
                break
            # try again
            print(f"Table name {table_name} already exists, trying again.")

        out_file = f"{tableinfo_dir}/{idx}_{table_name}.json"
        # BUG FIX: close the output file (the original leaked the handle
        # opened inline) and use pydantic v2's model_dump() to match the
        # model_validate() call used when reading the cache back.
        with open(out_file, "w") as f:
            json.dump(table_info.model_dump(), f)
        table_infos.append(table_info)
    return table_infos
# engine = create_engine("sqlite:///:memory:")
def create_table():
    """Build the SQLite database from the CSV files; returns (engine, table_infos)."""
    tableinfo_dir = "WikiTableQuestions_TableInfo"
    os.makedirs(tableinfo_dir, exist_ok=True)
    # Start from a clean database file on every run.
    if os.path.exists("wiki_table_questions.db"):
        os.remove("wiki_table_questions.db")
    # echo=True makes SQLAlchemy log every SQL statement it executes.
    engine = create_engine("sqlite:///wiki_table_questions.db",echo=True)
    metadata_obj = MetaData()
    dfs = process_file()
    tables_infos = write_table_info_dir(dfs,tableinfo_dir)
    for idx, df in enumerate(dfs):
        table_info = _get_tableinfo_with_index(idx,tableinfo_dir)
        print(f"Creating table: {table_info.table_name}")
        create_table_from_dataframe(df, table_info.table_name, engine, metadata_obj)
    return engine,tables_infos
def create_retriever(engine,table_infos):
    """Index table schemas for semantic retrieval; returns (obj_retriever, sql_database)."""
    # Create a SQLDatabase object
    sql_database = SQLDatabase(engine)
    # Map each table's name + summary into schema objects and index them
    # so the tables relevant to a query can be retrieved semantically.
    table_node_mapping = SQLTableNodeMapping(sql_database)
    table_schema_objs = [
        SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
        for t in table_infos
    ]  # add a SQLTableSchema for each table
    obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        VectorStoreIndex,
    )
    # Retrieve the 3 most similar table schemas per query.
    obj_retriever = obj_index.as_retriever(similarity_top_k=3)
    return obj_retriever,sql_database
def get_table_context_str(table_schema_objs: List[SQLTableSchema],sql_database):
    """Render each retrieved table's schema info (plus its description,
    when available) and join the pieces with blank lines."""
    context_strs = []
    for schema_obj in table_schema_objs:
        info = sql_database.get_single_table_info(schema_obj.table_name)
        if schema_obj.context_str:
            info += " The table description is: " + schema_obj.context_str
        context_strs.append(info)
    return "\n\n".join(context_strs)
def parse_response_to_sql(chat_response: "ChatResponse") -> str:
    """Extract a bare SQL statement from an LLM text-to-SQL completion.

    Handles the "SQLQuery: ... SQLResult: ..." scaffolding produced by the
    default prompt as well as surrounding markdown code fences.
    """
    response = chat_response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        response = response.removeprefix("SQLQuery:")
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    # BUG FIX: the original strip("```") only removes backtick characters,
    # leaving the "sql" language tag behind for ```sql fenced blocks.
    sql = response.strip()
    if sql.startswith("```sql"):
        sql = sql[len("```sql"):]
    elif sql.startswith("```"):
        sql = sql[len("```"):]
    if sql.endswith("```"):
        sql = sql[:-3]
    return sql.strip()
# Event: the relevant tables in the database were found.
class TableRetrieveEvent(Event):
    """Result of running table retrieval."""

    # Rendered schema/description text for the retrieved tables.
    table_context_str: str
    # The user's original natural-language question.
    query: str


# Event: the text was converted into SQL.
class TextToSQLEvent(Event):
    """Text-to-SQL event."""

    # The generated SQL statement.
    sql: str
    query: str
class TextToSQLWorkflow1(Workflow):
    """Text-to-SQL Workflow that does query-time table retrieval."""

    def __init__(
        self,
        obj_retriever,
        text2sql_prompt,
        sql_retriever,
        sql_database,
        response_synthesis_prompt,
        llm,
        *args,
        **kwargs
    ) -> None:
        """Init params."""
        super().__init__(*args, **kwargs)
        self.obj_retriever = obj_retriever
        self.text2sql_prompt = text2sql_prompt
        self.sql_retriever = sql_retriever
        self.sql_database = sql_database
        self.response_synthesis_prompt = response_synthesis_prompt
        self.llm = llm

    @step
    async def retrieve_tables(
        self, ctx: Context, ev: StartEvent
    ) -> TableRetrieveEvent:
        """Retrieve tables relevant to the user's question."""
        table_schema_objs = self.obj_retriever.retrieve(ev.query)
        table_context_str = get_table_context_str(table_schema_objs, self.sql_database)
        print("====\n" + table_context_str + "\n====")
        return TableRetrieveEvent(
            table_context_str=table_context_str, query=ev.query
        )

    @step
    async def generate_sql(
        self, ctx: Context, ev: TableRetrieveEvent
    ) -> TextToSQLEvent:
        """Generate SQL statement from the question plus table schemas."""
        fmt_messages = self.text2sql_prompt.format_messages(
            query_str=ev.query, schema=ev.table_context_str
        )
        chat_response = self.llm.chat(fmt_messages)
        sql = parse_response_to_sql(chat_response)
        print("====\n" + sql + "\n====")
        return TextToSQLEvent(sql=sql, query=ev.query)

    @step
    async def generate_response(self, ctx: Context, ev: TextToSQLEvent) -> StopEvent:
        """Run SQL retrieval and generate the natural-language reply."""
        retrieved_rows = self.sql_retriever.retrieve(ev.sql)
        print("====\n" + str(retrieved_rows) + "\n====")
        fmt_messages = self.response_synthesis_prompt.format_messages(
            sql_query=ev.sql,
            context_str=str(retrieved_rows),
            query_str=ev.query,
        )
        # BUG FIX: used the module-level ``llm`` instead of the injected
        # ``self.llm``.
        chat_response = self.llm.chat(fmt_messages)
        return StopEvent(result=chat_response)
async def task(lock):
    # Build the database, retriever and prompts, then run the workflow
    # end-to-end for a single hard-coded question.
    engine, table_infos = create_table()
    obj_retriever, sql_database = create_retriever(engine, table_infos)
    sql_retriever = SQLRetriever(sql_database)
    # Bind the engine's SQL dialect into the stock text-to-SQL prompt.
    text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
        dialect=engine.dialect.name
    )
    # Prompt for turning SQL result rows into a natural-language answer.
    response_synthesis_prompt_str = (
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: "
    )
    response_synthesis_prompt = PromptTemplate(
        response_synthesis_prompt_str,
    )
    workflow = TextToSQLWorkflow1(
        obj_retriever,
        text2sql_prompt,
        sql_retriever,
        sql_database,
        response_synthesis_prompt,
        llm,
        verbose=True,
    )
    # The lock serializes workflow runs if more tasks are added later.
    async with lock:
        response = await workflow.run(
            query="What was the year that The Notorious B.I.G was signed to Bad Boy?"
        )
    print(str(response))
async def main():
    """Entry coroutine: run a single task guarded by a shared lock."""
    lock = asyncio.Lock()
    await asyncio.gather(task(lock))


if __name__ == "__main__":
    asyncio.run(main())
from llama_index.utils.workflow import draw_all_possible_flows

# BUG FIX: the original referenced the undefined name ``test_to_sql_wf``;
# draw_all_possible_flows accepts the workflow class itself.
# NOTE(review): these lines sit after the __main__ guard and therefore run
# on plain import as well — confirm that is intended.
draw_all_possible_flows(
    TextToSQLWorkflow1, filename="text_to_sql_table_retrieval.html"
)