# Deep Dive into Amazon SageMaker Experiments: How to Track and Log ML Experiments

# Introduction

Amazon SageMaker is a fully managed service that helps developers build, train, and deploy machine learning models quickly and easily. SageMaker Experiments is a feature designed for organizing, tracking, comparing, and evaluating machine learning experiments and model versions. In this article, we explore how to use the LangChain callback to log prompts and other LLM hyperparameters to SageMaker Experiments, and demonstrate several usage scenarios.

# Main Content

## Installation and Setup

First, install the required libraries:

```bash
%pip install --upgrade --quiet sagemaker
%pip install --upgrade --quiet langchain-openai
%pip install --upgrade --quiet google-search-results
```

Next, set up the API keys:

```python
import os

# Add your API keys
os.environ["OPENAI_API_KEY"] = "<ADD-KEY-HERE>"
os.environ["SERPAPI_API_KEY"] = "<ADD-KEY-HERE>"
```

## LLM Prompt Tracking

### Setting the Hyperparameters and Experiment

We set a few basic hyperparameters for the LLM and create a SageMaker session:

```python
from sagemaker.session import Session

HPARAMS = {
    "temperature": 0.1,
    "model_name": "gpt-3.5-turbo-instruct",
}

BUCKET_NAME = None
EXPERIMENT_NAME = "langchain-sagemaker-tracker"
session = Session(default_bucket=BUCKET_NAME)
```

### Scenario Walkthroughs

#### Scenario 1 - A Single LLM

This is the simplest scenario: a single LLM generates an output for a given prompt. The following code shows how to log the process in SageMaker Experiments:

```python
from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAI
from sagemaker.experiments.run import Run
from langchain.chains import LLMChain

RUN_NAME = "run-scenario-1"
PROMPT_TEMPLATE = "tell me a joke about {topic}"
INPUT_VARIABLES = {"topic": "fish"}

with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    prompt = PromptTemplate.from_template(template=PROMPT_TEMPLATE)
    chain = LLMChain(llm=llm, prompt=prompt, callbacks=[sagemaker_callback])
    chain.run(**INPUT_VARIABLES)
    sagemaker_callback.flush_tracker()
```

#### Scenario 2 - Sequential Chain

In this scenario, we use a sequential chain that links two LLM calls:

```python
from langchain.chains import SimpleSequentialChain

RUN_NAME = "run-scenario-2"

PROMPT_TEMPLATE_1 = "You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\nTitle: {title}\nPlaywright: This is a synopsis for the above play:"
PROMPT_TEMPLATE_2 = "You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.\nPlay Synopsis: {synopsis}\nReview from a New York Times play critic of the above play:"

INPUT_VARIABLES = {
    "input": "documentary about good video games that push the boundary of game design"
}

with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:
    sagemaker_callback = SageMakerCallbackHandler(run)
    prompt_template1 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_1)
    prompt_template2 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_2)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    chain1 = LLMChain(llm=llm, prompt=prompt_template1, callbacks=[sagemaker_callback])
    chain2 = LLMChain(llm=llm, prompt=prompt_template2, callbacks=[sagemaker_callback])
    overall_chain = SimpleSequentialChain(chains=[chain1, chain2], callbacks=[sagemaker_callback])
    overall_chain.run(**INPUT_VARIABLES)
    sagemaker_callback.flush_tracker()
```

#### Scenario 3 - Agent with Tools

In the final scenario, we combine the LLM with several tools to handle a more complex task:

```python
from langchain.agents import initialize_agent, load_tools

RUN_NAME = "run-scenario-3"
PROMPT_TEMPLATE = "Who is the oldest person alive? And what is their current age raised to the power of 1.51?"

with Run(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session) as run:
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    tools = load_tools(["serpapi", "llm-math"], llm=llm, callbacks=[sagemaker_callback])
    agent = initialize_agent(tools, llm, agent="zero-shot-react-description", callbacks=[sagemaker_callback])
    agent.run(input=PROMPT_TEMPLATE)
    sagemaker_callback.flush_tracker()
```

## Loading the Log Data

The experiment records can be inspected by loading them into a Pandas DataFrame:

```python
from sagemaker.analytics import ExperimentAnalytics

logs = ExperimentAnalytics(experiment_name=EXPERIMENT_NAME)
df = logs.dataframe(force_refresh=True)

print(df.shape)
df.head()
```

The log data lets us analyze the LLM settings and hyperparameters of each experiment run in depth; these records are also persisted as JSON files in S3.
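As a minimal sketch of such an analysis (the exact column names depend on how the callback flattens the run parameters and on the SageMaker SDK version, so inspect `df.columns` first), you could narrow the DataFrame down to the logged LLM settings:

```python
# Hypothetical example: select columns that look like the logged LLM hyperparameters.
hparam_cols = [c for c in df.columns if "temperature" in c.lower() or "model_name" in c.lower()]
df[hparam_cols].head()
```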

# Frequently Asked Questions

Q1: Why are my OpenAI API requests failing?

A1: Check that your API key is correct and still valid. In addition, due to network restrictions in some regions, developers may need to consider routing requests through an API proxy service to improve access stability.
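If you do go through an OpenAI-compatible proxy, one possible approach, sketched here with a placeholder endpoint, is to point the client at it via an environment variable before constructing the LLM:

```python
import os

# Hypothetical proxy endpoint - replace with the OpenAI-compatible base URL of your provider.
os.environ["OPENAI_API_BASE"] = "<YOUR-PROXY-ENDPOINT>"
```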

Q2: Can the data recorded by SageMaker Experiments be exported to other platforms?

A2: Yes. You can use the Python client to download the data from S3, convert it into a Pandas DataFrame, and then export it to CSV or other formats.
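As a minimal sketch of that export step (reusing the `df` DataFrame loaded above; the file name is just a placeholder):

```python
# Write the experiment records loaded above to a CSV file for use in other tools.
df.to_csv("sagemaker_experiment_logs.csv", index=False)
```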

# Summary and Further Learning Resources

By combining LangChain with SageMaker Experiments, developers can easily organize and analyze machine learning experiments, especially complex tasks involving generative models. For deeper study, the following resources can help:

  1. Amazon SageMaker Experiments Documentation
  2. LangChain Documentation
  3. OpenAI API Documentation
