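# 使用efficientnet模型为基底微调训练图片分类模型
# Fine-tune an image-classification model on top of a pretrained EfficientNet backbone.
#
# (Original page metadata: 93 阅读 · 2分钟 — "93 reads, 2-minute read".)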
import os
import time

import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

save_model_path = r'bestmodel.pth'
train_path = r'./dataset/training_data/scene'
val_path = r'./dataset/validation_data/scene'
test_path = r'./dataset/test_data/scene'

# --- Hyper-parameters -----------------------------------------------------------
batch_size = 32
learning_rate = 1e-4
epoches = 20  # number of training epochs
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Datasets -------------------------------------------------------------------
# NOTE(review): 244 looks like a typo for the conventional 224; torchvision's
# EfficientNet-B4 ImageNet weights were evaluated at 380x380. Value kept as-is — confirm.
train_transform = transforms.Compose([
    transforms.Resize([244, 244]),
    transforms.RandomRotation(10),  # light augmentation, training split only
    transforms.ColorJitter(brightness=0.1),
    transforms.ToTensor(),
    # Standard ImageNet channel statistics used by torchvision's pretrained weights.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Validation and test share one deterministic (no augmentation) pipeline.
val_transform = test_transform = transforms.Compose([
    transforms.Resize([244, 244]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_data = ImageFolder(train_path, transform=train_transform)
val_data = ImageFolder(val_path, transform=val_transform)
test_data = ImageFolder(test_path, transform=test_transform)

# ImageFolder numbers classes from the sorted sub-folder names of *each* root
# independently, so a split with a missing/extra class folder would silently shift
# every label index. Fail loudly instead.
if not (train_data.class_to_idx == val_data.class_to_idx == test_data.class_to_idx):
    raise ValueError('train/val/test class folders differ; label indices would not line up')

# Sample counts as ImageFolder sees them. Counting files with os.walk also counted
# non-image files (e.g. desktop.ini, .DS_Store) that ImageFolder skips, which skewed
# the accuracy denominators.
train_data_len = len(train_data)
val_data_len = len(val_data)
test_data_len = len(test_data)

train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)
# --- Model ----------------------------------------------------------------------
model = models.efficientnet_b4(pretrained=True)  # newer torchvision prefers weights=...
# Take the class count from ImageFolder itself so it matches the label indices it
# produces (os.listdir would also count stray files in the dataset root).
class_num = len(train_data.classes)
# torchvision's EfficientNet has no ``fc`` attribute: its head is
# ``model.classifier = Sequential(Dropout, Linear)``. Assigning ``model.fc`` only
# registered an unused layer and left the 1000-way ImageNet head in place. Replace
# the final Linear, reusing its input width instead of hard-coding 2048.
model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, class_num)
model = model.to(device)
loss_cross = torch.nn.CrossEntropyLoss()
optim_adam = torch.optim.Adam(model.parameters(), lr=learning_rate)


def train(net, optim, loss, loader=None):
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    Args:
        net: model to train; switched to ``train()`` mode. Batches are moved to the
            device holding its parameters.
        optim: optimizer over ``net``'s parameters; stepped once per batch.
        loss: criterion called as ``loss(outputs, labels)``. Assumed to use mean
            reduction (the default for ``CrossEntropyLoss``).
        loader: iterable of ``(images, labels)`` batches; defaults to the
            module-level ``train_loader``.

    Returns:
        Per-sample mean loss and top-1 accuracy over every sample the loader yielded.

    Raises:
        ValueError: if the loader yields no samples.
    """
    if loader is None:
        loader = train_loader
    target = next(net.parameters()).device
    net.train()
    total_loss = 0.0
    total_corrects = 0
    total_seen = 0
    for images, labels in loader:
        images = images.to(target)
        labels = labels.to(target)
        optim.zero_grad()
        outputs = net(images)
        loss_obj = loss(outputs, labels)
        loss_obj.backward()
        optim.step()
        batch_n = labels.size(0)
        # Weight by batch size so a short final batch does not skew the epoch mean.
        total_loss += loss_obj.item() * batch_n
        total_corrects += (outputs.argmax(dim=1) == labels).sum().item()
        total_seen += batch_n
    if total_seen == 0:
        raise ValueError('train: loader yielded no samples')
    return total_loss / total_seen, total_corrects / total_seen


def evaluate(net, loss, loader=None):
    """Score ``net`` on a held-out split without updating it.

    Args:
        net: model to evaluate; switched to ``eval()`` mode. Batches are moved to the
            device holding its parameters.
        loss: criterion called as ``loss(outputs, labels)``. Assumed to use mean
            reduction (the default for ``CrossEntropyLoss``).
        loader: iterable of ``(images, labels)`` batches; defaults to the
            module-level ``val_loader``.

    Returns:
        ``(mean_loss, accuracy)``: per-sample mean loss and top-1 accuracy over every
        sample the loader yielded.

    Raises:
        ValueError: if the loader yields no samples.
    """
    if loader is None:
        loader = val_loader
    target = next(net.parameters()).device
    net.eval()
    total_eval_loss = 0.0
    total_corrects = 0
    total_seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(target)
            labels = labels.to(target)
            outputs = net(images)
            batch_n = labels.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            total_eval_loss += loss(outputs, labels).item() * batch_n
            total_corrects += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += batch_n
    if total_seen == 0:
        raise ValueError('evaluate: loader yielded no samples')
    return total_eval_loss / total_seen, total_corrects / total_seen


def test():
    """Load the saved best checkpoint and report test accuracy and per-image time.

    Loads the state dict at ``save_model_path`` into the module-level ``model``.
    Runs it over ``test_loader`` and prints the fraction of correctly classified
    images. It also prints the mean wall-clock milliseconds per image, truncated to
    an int; this figure includes data loading.

    Raises:
        ValueError: if the test split contains no images.
    """
    # map_location=device puts the tensors straight onto the device the model lives
    # on, so the same checkpoint loads on both CPU-only and CUDA machines.
    # NOTE: torch.load unpickles; only point it at checkpoints you produced yourself.
    state = torch.load(save_model_path, map_location=device)
    model.load_state_dict(state, strict=True)
    model.eval()
    num_samples = len(test_loader.dataset)
    if num_samples == 0:
        raise ValueError(f'no test images found under {test_path!r}')
    total_corrects = 0
    start = time.perf_counter()  # monotonic clock, meant for measuring intervals
    with torch.no_grad():
        for image, label in test_loader:
            image = image.to(device)
            label = label.to(device)
            pred = model(image)
            total_corrects += (pred.argmax(dim=1) == label).sum().item()
    elapsed_ms = (time.perf_counter() - start) * 1000
    print(f'correct rate:{total_corrects / num_samples}, run v: {int(elapsed_ms / num_samples)}')


def main():
    """Fine-tune for ``epoches`` epochs, checkpointing on validation improvement.

    Each epoch prints three numbers: epoch, train loss, train accuracy. It then prints
    epoch, val loss, val accuracy. When the validation accuracy beats the best seen so
    far, it also prints epoch and val accuracy, and saves ``model.state_dict()`` to
    ``save_model_path``.
    """
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. With 0, a run whose accuracy stayed at 0.0 saved nothing and the
    # later test() call failed to find the file.
    best_acc = float('-inf')
    for epoch in range(1, epoches + 1):
        train_loss, train_acc = train(model, optim_adam, loss_cross)
        print(epoch, train_loss, train_acc)
        val_loss, val_acc = evaluate(model, loss_cross)
        print(epoch, val_loss, val_acc)

        if val_acc > best_acc:
            print(epoch, val_acc)
            torch.save(model.state_dict(), save_model_path)
            best_acc = val_acc


if __name__ == "__main__":
    # Script entry point: train and keep the best checkpoint, then score it on the test split.
    main()
    test()