import os
import time
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
save_model_path = r'bestmodel.pth'
train_path = r'./dataset/training_data/scene'
val_path = r'./dataset/validation_data/scene'
test_path = r'./dataset/test_data/scene'

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
batch_size = 32
learning_rate = 1e-4
epoches = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _count_files(root):
    """Return the number of files anywhere below *root* (0 if it is missing)."""
    total = 0
    for _dirpath, _dirnames, filenames in os.walk(root):
        total += len(filenames)
    return total


# ---------------------------------------------------------------------------
# Dataset sizes: total file count under each split directory, used later as
# the denominator when computing per-epoch accuracy.
# ---------------------------------------------------------------------------
train_data_len = _count_files(train_path)
val_data_len = _count_files(val_path)
test_data_len = _count_files(test_path)
# ImageNet normalisation statistics (those of the pretrained backbone).
# Normalize is stateless, so one instance can be shared by both pipelines.
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])

# NOTE(review): 244x244 looks like a typo for the conventional 224x224
# (EfficientNet-B4 is natively trained at 380x380) -- confirm the intended
# input resolution before changing it.
train_transform = transforms.Compose([
    transforms.Resize([244, 244]),
    # Light augmentation for training only.
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1),
    transforms.ToTensor(),
    _normalize,
])
val_transform = test_transform = transforms.Compose([
    transforms.Resize([244, 244]),
    transforms.ToTensor(),
    _normalize,
])

# One folder-per-class dataset for each split.
train_data = ImageFolder(train_path, transform=train_transform)
val_data = ImageFolder(val_path, transform=val_transform)
test_data = ImageFolder(test_path, transform=test_transform)

# Shuffle only the training split; the test set is evaluated one image at a
# time so per-image latency can be measured.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)
# ---------------------------------------------------------------------------
# Model: pretrained EfficientNet-B4 with its classification head replaced so
# it predicts one logit per scene class.
# ---------------------------------------------------------------------------
model = models.efficientnet_b4(pretrained=True)
# Number of classes = number of class sub-directories ImageFolder discovered.
# (os.listdir would also count stray non-directory files.)
class_num = len(train_data.classes)
# BUG FIX: torchvision's EfficientNet has no `fc` attribute. The original
# `model.fc = Linear(2048, class_num)` merely attached an extra, never-used
# layer and left the real 1000-way ImageNet head (`model.classifier[1]`) in
# place -- the network could never output `class_num` logits. Replace the
# actual head, reading its true input width (1792 for B4, not 2048).
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, class_num)
model = model.to(device)

loss_cross = torch.nn.CrossEntropyLoss()
optim_adam = torch.optim.Adam(model.parameters(), lr=learning_rate)
def train(net, optim, loss):
    """Run one training epoch over ``train_loader``.

    Args:
        net: the model to train (put into train mode).
        optim: optimizer stepping ``net``'s parameters.
        loss: criterion mapping (logits, labels) to a scalar loss.

    Returns:
        (mean batch loss, accuracy over the whole training set).
    """
    net.train()
    total_loss = 0.0
    total_corrects = 0
    for images, labels in train_loader:
        # Variable() wrappers removed: deprecated no-ops since PyTorch 0.4.
        images = images.to(device)
        labels = labels.to(device)
        optim.zero_grad()
        outputs = net(images)
        batch_loss = loss(outputs, labels)
        batch_loss.backward()
        optim.step()
        total_loss += batch_loss.item()
        # Count correct predictions on-device; no numpy round-trip needed.
        preds = outputs.argmax(dim=1)
        total_corrects += (preds == labels).sum().item()
    return total_loss / float(len(train_loader)), total_corrects / train_data_len
def evaluate(net, loss):
    """Evaluate ``net`` on ``val_loader`` without updating weights.

    Args:
        net: the model to evaluate (put into eval mode).
        loss: criterion mapping (logits, labels) to a scalar loss.

    Returns:
        (mean batch loss, accuracy over the whole validation set).
    """
    net.eval()
    total_corrects = 0
    total_eval_loss = 0.0
    with torch.no_grad():
        for images, labels in val_loader:
            # Variable() wrappers removed: deprecated no-ops since PyTorch 0.4.
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            total_eval_loss += loss(outputs, labels).item()
            # Count correct predictions on-device; no numpy round-trip needed.
            preds = outputs.argmax(dim=1)
            total_corrects += (preds == labels).sum().item()
    return total_eval_loss / float(len(val_loader)), total_corrects / val_data_len
def test():
    """Evaluate the best saved checkpoint on the test set.

    Loads the weights from ``save_model_path`` (mapped to CPU first so a
    GPU-trained checkpoint also loads on a CPU-only machine), then prints the
    test accuracy and the mean inference latency in ms per image.
    """
    state = torch.load(save_model_path, map_location='cpu')
    model.load_state_dict(state, strict=True)
    model.eval()
    total_corrects = 0
    # perf_counter is monotonic, so the measurement is immune to wall-clock
    # adjustments (time.time() is not).
    start = time.perf_counter()
    with torch.no_grad():
        for image, label in test_loader:
            image = image.to(device)
            label = label.to(device)
            pred = model(image)
            preds = pred.argmax(dim=1)
            total_corrects += (preds == label).sum().item()
    end = time.perf_counter()
    # Same report as before: accuracy, then mean latency in ms per image.
    print(f'correct rate:{total_corrects / test_data_len}, run v: {int((end - start) * 1000 / test_data_len)}')
def main():
    """Train for ``epoches`` epochs, checkpointing whenever validation
    accuracy improves on the best seen so far."""
    best_acc = 0
    for epoch in range(epoches):
        train_loss, train_acc = train(model, optim_adam, loss_cross)
        print(epoch + 1, train_loss, train_acc)
        val_loss, val_acc = evaluate(model, loss_cross)
        print(epoch + 1, val_loss, val_acc)
        if val_acc <= best_acc:
            continue
        # New best validation accuracy: report it and persist the weights.
        print(epoch + 1, val_acc)
        torch.save(model.state_dict(), save_model_path)
        best_acc = val_acc
# Script entry point: train (saving the best checkpoint along the way), then
# reload that checkpoint and evaluate it on the test set.
if __name__ == '__main__':
    main()
    test()