# 预测 TopK
def predict_ch3_topK(net, test_iter, k=3):
    """Print top-k predictions for the misclassified samples of one test batch.

    Takes only the first ``(X, y)`` batch from ``test_iter``, runs ``net`` on
    it, and for every sample whose top-1 prediction differs from the true
    label prints the true label, the ``k`` highest-scoring predicted labels,
    and their scores. Correctly classified samples are skipped silently.

    Args:
        net: Model called as ``net(X)``; assumed to return one row of class
            scores per sample, e.g. ``(batch, num_classes)`` -- TODO confirm.
        test_iter: Iterable yielding ``(X, y)`` batches; only the first batch
            is consumed.
        k: Number of top-scoring predictions to show per sample.

    Raises:
        ValueError: If ``test_iter`` yields no batches.
    """
    # The original `for ...: break` left X/y unbound (NameError) on an empty
    # iterator; fail with an explicit message instead.
    try:
        X, y = next(iter(test_iter))
    except StopIteration:
        raise ValueError("test_iter yielded no batches") from None

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        res_X = net(X)

    for true, pred in zip(y, res_X):
        top_k_value, top_k_indices = torch.topk(pred, k=k)
        # Only report samples whose best guess is wrong.
        if top_k_indices[0].item() == true:
            continue
        # Label lookups are done only for the samples that are printed.
        true_label = d2l.get_fashion_mnist_labels([true])
        pred_label = d2l.get_fashion_mnist_labels(top_k_indices)
        print("真实值:" + str(true_label), end="\t")
        print("预测值:" + str(pred_label), str(top_k_value.data))
# 将 PyTorch 模型转换为 ONNX 格式
import torch
from torch import nn
# MLP layer sizes; 784 presumably corresponds to flattened 28x28 images
# (Fashion-MNIST in the surrounding code) -- confirm against the checkpoint.
num_inputs = 784
num_outputs = 10
num_hiddens = 256

if __name__ == '__main__':
    # Architecture must match the one used when mlp_nn.pth was saved,
    # otherwise load_state_dict fails on missing/unexpected keys.
    net = nn.Sequential(nn.Flatten(),
                        nn.Linear(num_inputs, num_hiddens),
                        nn.ReLU(),
                        nn.Linear(num_hiddens, num_outputs))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = "mlp_nn.pth"
    # map_location lets a checkpoint saved on one device load on another
    # (e.g. a GPU-saved file on a CPU-only host).
    net.load_state_dict(torch.load(model_path, map_location=device))
    # Model and dummy input must be on the same device; previously only the
    # input was moved, which breaks export tracing on CUDA machines.
    net.to(device)
    net.eval()

    # Dummy input used only to trace the graph. Without dynamic_axes the
    # exported model's input shape is fixed to this batch size.
    input_shape = (256, num_inputs)
    x = torch.randn(input_shape, device=device)
    output_file = "mlp_nn.onnx"
    torch.onnx.export(net, x, output_file, export_params=True)