高效利用AWS SageMaker加载自定义模型:深入解析与实用示例

56 阅读2分钟

引言

在AI模型部署领域,AWS SageMaker以其灵活性和强大功能吸引了众多开发者。本文旨在介绍如何利用SageMaker部署自定义模型,并通过Python代码示例展示与模型通信的具体实现。特别是,我们将讨论如何使用SagemakerEndpointEmbeddings类,这对于那些在SageMaker上托管自己模型的开发者非常有用。

主要内容

1. SageMaker与自定义模型

AWS SageMaker提供了大规模部署机器学习模型的便捷途径。通过SageMaker,你可以轻松将本地开发的模型托管在云端,并利用其强大的计算资源进行推理。

2. 处理批量请求

为了提升吞吐量,处理批量请求是不可或缺的一环。在自定义的inference.py脚本中,调整predict_fn()函数的返回值,可以有效处理批量数据:

# 从
return {"vectors": sentence_embeddings[0].tolist()}
# 修改为
return {"vectors": sentence_embeddings.tolist()}

3. 环境准备

在开始编写代码前,确保安装必要的Python包:

!pip3 install langchain boto3

代码示例

下面我们将展示如何将自定义模型部署到SageMaker,并进行推理。

import json
from typing import Dict, List
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler

class ContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": inputs, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["vectors"]

content_handler = ContentHandler()

embeddings = SagemakerEndpointEmbeddings(
    endpoint_name="huggingface-pytorch-inference-2023-03-21-16-14-03-834",
    region_name="us-east-1",
    content_handler=content_handler,
)

# 使用API代理服务提高访问稳定性
query_result = embeddings.embed_query("foo")
doc_results = embeddings.embed_documents(["foo"])
print(doc_results)

常见问题和解决方案

问题1:网络访问不稳定

在某些地区,由于网络限制,可能需要配置API代理服务来稳定访问。

问题2:数据格式不匹配

确保请求和响应的数据格式与SageMaker上的inference.py保持一致。

总结和进一步学习资源

AWS SageMaker为模型的托管和推理提供了强大的支持。通过有效利用Python SDK,开发者可以轻松实现自定义模型的部署。为了深入学习,建议参考官方文档和相关教程:

参考资料

  • AWS SageMaker官方文档
  • Langchain库使用示例

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---