李沐的深度学习课 课程笔记3 to ONNX

11 阅读1分钟

预测 TopK

def predict_ch3_topK(net, test_iter, k=3):  # @save
    for X, y in test_iter:
        break
    res_X = net(X)
    for true, pred in zip(y, res_X):
        true_label = d2l.get_fashion_mnist_labels([true])
        top_k_value, top_k_indices = torch.topk(pred, k=k)
        pred_label = d2l.get_fashion_mnist_labels(top_k_indices)
        if top_k_indices[0].item() == true:
            continue
        print("真实值:" + str(true_label), end="\t")
        print("预测值:" + str(pred_label), str(top_k_value.data))

将 PyTorch 模型转换为 ONNX 格式

import torch
from torch import nn

num_inputs = 784
num_outputs = 10
num_hiddens = 256

if __name__ == '__main__':
    net = nn.Sequential(nn.Flatten(),
                        nn.Linear(num_inputs, num_hiddens),
                        nn.ReLU(),
                        nn.Linear(num_hiddens, num_outputs))
    model_path = "mlp_nn.pth"
    net.load_state_dict(torch.load(model_path))
    net.eval()

    ########################################
    # pip install onnx
    # pip install onnxruntime
    # 定义输入张量,需要与模型的输入张量形状相同
    input_shape = (256, 784)  # num_hiddens num_inputs
    x = torch.randn(input_shape)
    # 需要指定输入张量,输出文件路径和运行设备
    # 默认情况下,输出张量的名称将基于模型中的名称自动分配
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 将 PyTorch 模型转换为 ONNX 格式
    output_file = "mlp_nn.onnx"
    torch.onnx.export(net, x.to(device), output_file, export_params=True)