SageMaker JumpStart: Deploying and Testing Whisper

Create a SageMaker domain and user profile

Simply click Set up.

Create a user profile.

Click Next until the profile has been created successfully.
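
The console steps above can also be scripted. What follows is a minimal sketch and is not part of the original walkthrough. It calls boto3's SageMaker create_user_profile for an existing domain, then polls describe_user_profile until the profile reaches InService, which matches "click Next until the profile is created". The domain ID and profile name in the usage comment are hypothetical placeholders. The client is passed in as a parameter, so you can exercise the logic without AWS credentials.

```python
import time

# Terminal failure states reported by DescribeUserProfile.
FAILED_STATES = {"Failed", "Update_Failed", "Delete_Failed"}


def wait_for_user_profile(sm_client, domain_id, profile_name,
                          poll_seconds=15, timeout_seconds=900):
    """Poll DescribeUserProfile until the profile is InService."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = sm_client.describe_user_profile(
            DomainId=domain_id, UserProfileName=profile_name
        )["Status"]
        if status == "InService":
            return status
        if status in FAILED_STATES:
            raise RuntimeError(f"user profile {profile_name!r} is in state {status}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"user profile {profile_name!r} is still {status}")
        time.sleep(poll_seconds)


def create_user_profile(sm_client, domain_id, profile_name, **wait_kwargs):
    """Create a Studio user profile in an existing domain and wait for it."""
    sm_client.create_user_profile(DomainId=domain_id, UserProfileName=profile_name)
    return wait_for_user_profile(sm_client, domain_id, profile_name, **wait_kwargs)


# Usage (requires AWS credentials; the IDs below are hypothetical placeholders):
#   import boto3
#   sm = boto3.client("sagemaker")
#   create_user_profile(sm, "d-xxxxxxxxxxxx", "whisper-demo-user")
```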

Create an inference endpoint with JumpStart

Open SageMaker Studio.

Open the JumpStart page.

Under Providers, search for Hugging Face.

Under Models, search for whisper and click Whisper Large V3.

Click Deploy to create the inference endpoint.
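
The deployed endpoint receives an auto-generated name. The sample code further down hard-codes one such name, jumpstart-dft-hf-asr-whisper-large-20250808-035351. A minimal sketch for looking the name up instead is shown here. It assumes your Whisper endpoint's name contains "whisper" and calls boto3's list_endpoints, filtered to InService endpoints and sorted newest first. The helper name find_latest_endpoint is hypothetical. If the endpoint is still being created, boto3's built-in waiter, get_waiter("endpoint_in_service"), can block until the endpoint is ready.

```python
def find_latest_endpoint(sm_client, name_contains="whisper"):
    """Return the newest InService endpoint whose name contains name_contains, or None."""
    response = sm_client.list_endpoints(
        NameContains=name_contains,
        StatusEquals="InService",
        SortBy="CreationTime",
        SortOrder="Descending",
        MaxResults=1,
    )
    endpoints = response["Endpoints"]
    return endpoints[0]["EndpointName"] if endpoints else None


# Usage (requires AWS credentials; not executed here):
#   import boto3
#   sm = boto3.client("sagemaker")
#   name = find_latest_endpoint(sm)
#   sm.get_waiter("endpoint_in_service").wait(EndpointName=name)
```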

Invoke the inference endpoint

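This page shows sample code for invoking the inference endpoint from Python. The sample uses boto3, plus a helper from the SageMaker Python SDK that locates the sample data: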

import json
import boto3
from sagemaker.jumpstart import utils

# Input WAV files must be sampled at 16 kHz (required by the automatic speech
# recognition models), so resample them if necessary. Each input audio file
# must be shorter than 30 seconds.
s3_bucket = utils.get_jumpstart_content_bucket(boto3.Session().region_name)
key_prefix = "training-datasets/asr_notebook_data"
input_audio_file_name = "sample1.wav"

# Download the first sample clip from the JumpStart content bucket.
s3_client = boto3.client("s3")
s3_client.download_file(s3_bucket, f"{key_prefix}/{input_audio_file_name}", input_audio_file_name)

with open(input_audio_file_name, "rb") as file:
    wav_file_read = file.read()

# Replace with the name of the endpoint created by your JumpStart deployment.
endpoint_name = "jumpstart-dft-hf-asr-whisper-large-20250808-035351"

def query_endpoint(body, content_type):
    """Send a request to the endpoint, print the returned text, and return it."""
    client = boto3.client("sagemaker-runtime")
    response = client.invoke_endpoint(EndpointName=endpoint_name, ContentType=content_type, Body=body)
    model_predictions = json.loads(response["Body"].read())
    text = model_predictions["text"]
    print(f"Text: {text}")
    return text

# Transcribe the raw WAV bytes. If you receive a client error (413), check the
# payload size: SageMaker invoke endpoint requests are limited to about 5 MB.
query_endpoint(wav_file_read, "audio/wav")

# Same requirements as above: 16 kHz sample rate, shorter than 30 seconds.
input_audio_file_name = "sample_french1.wav"
s3_client.download_file(s3_bucket, f"{key_prefix}/{input_audio_file_name}", input_audio_file_name)

with open(input_audio_file_name, "rb") as file:
    wav_file_read = file.read()

# Send the audio hex-encoded in a JSON payload and ask the model to translate
# the French speech (Whisper's "translate" task produces English text).
payload = {"audio_input": wav_file_read.hex(),
           "language": "french",
           "task": "translate"}

# If you receive a client error (413), check the payload size (limit: about 5 MB).
query_endpoint(json.dumps(payload).encode("utf-8"), "application/json")

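The code above:

  • downloads two sample audio files from S3 with boto3
  • invokes the inference endpoint and prints the result. It first sends sample1.wav as raw audio for transcription, then sends sample_french1.wav as a hex-encoded JSON payload for translation.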
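
The sample's comments require 16 kHz input shorter than 30 seconds. The following standard-library-only sketch is not from the JumpStart sample. It resamples a mono 16-bit PCM WAV to 16 kHz and rejects clips that are too long, so you can apply it before calling query_endpoint. The helper names prepare_wav, resample_linear, and json_payload_size are hypothetical. The linear interpolation has no anti-aliasing filter, so treat it as illustrative; for real audio, a dedicated resampling library is the safer choice.

The payload arithmetic works as follows. A 30-second, 16 kHz, 16-bit mono clip holds 16,000 × 2 × 30 = 960,000 bytes of samples. The JSON request hex-encodes those bytes, which doubles them to about 1.92 MB. That is still under the roughly 5 MB limit, so a 413 error usually points to audio that was not resampled or trimmed.

```python
import array
import io
import json
import sys
import wave

TARGET_RATE = 16_000   # the sample requires 16 kHz input
MAX_SECONDS = 30.0     # and clips shorter than 30 seconds


def resample_linear(samples, src_rate, dst_rate=TARGET_RATE):
    """Resample int16 samples by linear interpolation (no anti-aliasing filter)."""
    if src_rate == dst_rate:
        return array.array("h", samples)
    n_out = int(len(samples) * dst_rate / src_rate)
    step = src_rate / dst_rate
    last = len(samples) - 1
    out = array.array("h")
    for i in range(n_out):
        pos = i * step
        j = int(pos)
        frac = pos - j
        a, b = samples[j], samples[min(j + 1, last)]
        out.append(round(a + (b - a) * frac))
    return out


def prepare_wav(wav_bytes):
    """Return 16 kHz mono 16-bit WAV bytes, or raise ValueError."""
    with wave.open(io.BytesIO(wav_bytes), "rb") as src:
        channels, width, rate = src.getnchannels(), src.getsampwidth(), src.getframerate()
        frames = src.readframes(src.getnframes())
    if channels != 1 or width != 2:
        raise ValueError("this sketch only handles mono 16-bit PCM WAV")

    samples = array.array("h")
    samples.frombytes(frames)
    if sys.byteorder == "big":  # WAV sample data is little-endian
        samples.byteswap()
    samples = resample_linear(samples, rate)

    duration = len(samples) / TARGET_RATE
    if duration >= MAX_SECONDS:
        raise ValueError(f"audio is {duration:.1f} s; it must be under {MAX_SECONDS:.0f} s")

    if sys.byteorder == "big":  # convert back to little-endian before writing
        samples.byteswap()
    buf = io.BytesIO()
    with wave.open(buf, "wb") as dst:
        dst.setnchannels(1)
        dst.setsampwidth(2)
        dst.setframerate(TARGET_RATE)
        dst.writeframes(samples.tobytes())
    return buf.getvalue()


def json_payload_size(wav_bytes, language="french", task="translate"):
    """Size in bytes of the sample's JSON request body (hex doubles the audio size)."""
    payload = {"audio_input": wav_bytes.hex(), "language": language, "task": task}
    return len(json.dumps(payload).encode("utf-8"))
```

For example, run wav_file_read = prepare_wav(wav_file_read) right after reading each file in the sample code.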

Test results: [screenshot of the test output]