PyTorch项目实战14——使用集成学习思想训练识别模型

399 阅读2分钟

在前边我们使用了 CNNNet 模型、带有全局平均池化的 CNNNet 模型、以及 LeNet 模型,其识别的准确率基本都在60%左右。

这次把 应用了全局平均池化的CNNNet、CNNNet 和 LeNet 集成起来,实现 1+1>2 的效果。

1 模型定义

将三个模型的定义,包括层级定义及层级的串联放在同一个python文件中。

image.png

2 模型训练

  • 在训练模型时,定义了多层级的神经网络,将上边三种训练模型加入到模型数组中;

  • 在随机剃度下降中,使用 Adam() 方法,并对模型数组进行遍历,获取每一个模型中的参数:

optimizer = optim.Adam([{"params": mlp.parameters()} for mlp in mlps], lr=0.001)

  • 在正向和反向传播中,遍历模型数组,并使用 mlp.train() 激活对当前模型的训练
# 定义多层级神经网络
net0 = Net()
net1 = CNNNet()
net2 = LeNet()
mlps = [net0.to(device), net1.to(device), net2.to(device)]

# 训练模型
# 获得交叉乘损失
criterion = nn.CrossEntropyLoss()
# 随机梯度下降
optimizer = optim.Adam([{"params": mlp.parameters()} for mlp in mlps], lr=0.001)

for epoch in range(10):
    # 遍历训练集
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        # 将输入数据放入显存中
        inputs, labels = inputs.to(device), labels.to(device)
        # 权重参数梯度清零
        optimizer.zero_grad()
        # 正向和反向传播
        for mlp in mlps:
            mlp.train()
            outputs = mlp(inputs)
            # 获取交叉损失
            loss = criterion(outputs, labels)
            # 获取梯度
            loss.backward()
        # 更新参数
        optimizer.step()

3 模型测试

因为是多学习模型,所以在测试中,引入投票机制,该机制需要学习模型的数量为奇数。用于将被测图片所属类别进行归类。

pre = []
# 投票
vote_correct = 0
mlps_correct = [0 for i in range(len(mlps))]
for img, label in testloader:
    img, label = img.to(device), label.to(device)
    for i, mlp in enumerate(mlps):
        mlp.eval()
        out = mlp(img)
        # 按行取最大值
        _, prediction = torch.max(out, 1)
        pre_num = prediction.cpu().numpy()
        mlps_correct[i] += (pre_num == label.cpu().numpy()).sum()
        pre.append(pre_num)
    arr = np.array(pre)
    pre.clear()
    result = [Counter(arr[:, i]).most_common(1)[0][0] for i in range(4)]
    vote_correct += (result == label.cpu().numpy()).sum()
print("epoch: " + str(epoch) + " 集成学习模型的正确率 " + str(vote_correct/len(testloader)))

for idx, correct in enumerate(mlps_correct):
    print("模型 " + str(idx) + " 的正确率为: " + str(correct/len(testloader)))

4 运行程序

迭代10次后的输出如下:

image.png

集成学习模型的正确率和各模型自身的正确率,随着迭代的进行,都在提高。

随着迭代的继续深入,集成学习模型的正确率,将会是最好的。