摘要
FastViT是一种混合ViT架构,它通过引入一种新型的token混合算子RepMixer,实现了最先进的延迟-精度权衡。RepMixer通过消除网络中的跳跃连接(skip connection)来降低内存访问成本。FastViT进一步采用训练时过参数化(train-time overparameterization)和大核卷积来提高精度,并通过实验表明这些选择对延迟的影响很小。实验结果表明,FastViT在移动设备上的速度比最近的混合Transformer架构CMT快3.5倍,比EfficientNet快4.9倍,比ConvNeXt快1.9倍。在相近的延迟下,FastViT在ImageNet上的Top-1精度比MobileOne高出4.2%。此外,FastViT模型能够较好地适应域外数据和损坏数据,相较于其它SOTA架构具备更强的鲁棒性和泛化性能。
官方开源代码点这里 全部代码在这里 某度云太慢了,也欢迎私信我取代码。
本文使用植物数据集进行实验,并参考了CSDN博主AI浩的文章。下面进入实战环节。
实战
配置环境
# python环境及torch环境
conda create -n fastvit python=3.9
conda activate fastvit
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
所需要的安装包
# timm
pip install timm
# scipy
pip install scipy
# grad-cam
pip install grad-cam
# mmcv
pip install -U openmim
mim install mmcv
## 数据增强Cutout和Mixup
pip install torchtoolbox
下面进入代码环节
项目结构如下
- data文件夹是数据集文件,具体内容见网盘。
- models文件夹,FastVit的模型文件,来源于官方开源的代码中。
- test测试集,详情见网盘
- timm为工具包

项目的具体过程为:
- 执行makedata.py文件,准备数据集,创建训练集和验证集。
# makedata.py
import glob
import os
import shutil
from sklearn.model_selection import train_test_split


def _copy_into_split(files, split_root):
    """Copy every file in ``files`` to ``split_root/<class>/<file name>``.

    The class name is taken from the file's parent directory. Class folders
    are created on demand.
    """
    for src in files:
        # Normalise Windows separators first so the split works for paths
        # produced on either platform.
        parts = src.replace("\\", "/").split('/')
        class_name, file_name = parts[-2], parts[-1]
        class_dir = os.path.join(split_root, class_name)
        os.makedirs(class_dir, exist_ok=True)
        shutil.copy(src, class_dir + '/' + file_name)


# Collect every image laid out as data1/<class>/<file>.<ext>.
image_list = glob.glob('data1/*/*.*')
print(image_list)

# Start from an empty output folder: remove any previous split, then recreate.
file_dir = 'data'
if os.path.exists(file_dir):
    print('true')
    shutil.rmtree(file_dir)
os.makedirs(file_dir)

# 80/20 train/validation split with a fixed seed so the split is repeatable.
trainval_files, val_files = train_test_split(image_list, test_size=0.2, random_state=42)

train_dir = 'train'
val_dir = 'val'
train_root = os.path.join(file_dir, train_dir)
val_root = os.path.join(file_dir, val_dir)

_copy_into_split(trainval_files, train_root)
_copy_into_split(val_files, val_root)
- 执行训练文件train.py
import json
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.utils import accuracy, AverageMeter, ModelEma
from sklearn.metrics import classification_report
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from models.fastvit import fastvit_t8
from torch.autograd import Variable
from torchvision import datasets
# Turn off cuDNN benchmark mode.
torch.backends.cudnn.benchmark = False
import warnings
# Suppress all Python warnings for the rest of the run.
warnings.filterwarnings("ignore")
# Restrict the CUDA devices visible to this process to GPUs 0 and 1.
os.environ['CUDA_VISIBLE_DEVICES']="0,1"
# Training procedure for one epoch.
def train(model, device, train_loader, optimizer, epoch, model_ema):
    """Train ``model`` for one epoch and return ``(average_loss, top1_accuracy)``.

    Relies on names defined by the training script at module level:
    ``mixup_fn``, ``criterion_train``, ``use_amp``, ``scaler`` (only when
    ``use_amp`` is true) and ``CLIP_GRAD``.

    Args:
        model: Network to optimise.
        device: Device that every batch is moved to.
        train_loader: Loader yielding ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in progress output.
        model_ema: EMA wrapper updated after every step, or ``None``.

    Returns:
        Tuple of the epoch's average loss and average top-1 accuracy. The
        accuracy is measured against the original (pre-mixup) labels.
    """
    model.train()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        samples, targets = mixup_fn(data, target)
        optimizer.zero_grad()
        if use_amp:
            # The forward pass must run inside autocast, otherwise mixed
            # precision is never applied to the model.
            with torch.cuda.amp.autocast():
                output = model(samples)
                loss = torch.nan_to_num(criterion_train(output, targets))
            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to the true
            # gradient magnitudes rather than the loss-scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)
            # Steps the optimizer, or skips the step if gradients overflowed.
            scaler.step(optimizer)
            # Updates the scale for the next iteration.
            scaler.update()
        else:
            output = model(samples)
            loss = criterion_train(output, targets)
            loss.backward()
            optimizer.step()
        if model_ema is not None:
            model_ema.update(model)
        # Only synchronise when CUDA exists; the call raises on CPU-only hosts.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        lr = optimizer.param_groups[0]['lr']
        batch_size = target.size(0)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        loss_meter.update(loss.item(), batch_size)
        acc1_meter.update(acc1.item(), batch_size)
        acc5_meter.update(acc5.item(), batch_size)
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
    ave_loss = loss_meter.avg
    acc = acc1_meter.avg
    print('epoch:{}\tloss:{:.2f}\tacc:{:.2f}'.format(epoch, ave_loss, acc))
    return ave_loss, acc
# Validation procedure.
@torch.no_grad()
def val(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and write checkpoints.

    Uses names defined by the training script at module level:
    ``criterion_val``, ``file_dir``, ``epoch``, ``use_ema`` and the running
    best score ``Best_ACC`` (updated here).

    Side effects: overwrites ``best.pth`` whenever the top-1 accuracy beats
    ``Best_ACC``, and always writes a per-epoch ``model_<epoch>_<acc>.pth``.

    Returns:
        ``(true_labels, predicted_labels, average_loss, top1_accuracy)``.
    """
    global Best_ACC
    model.eval()
    loss_meter, acc1_meter, acc5_meter = AverageMeter(), AverageMeter(), AverageMeter()
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    val_list, pred_list = [], []
    for data, target in test_loader:
        # Record ground-truth labels before the batch is moved to the device.
        val_list.extend(target.tolist())
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        output = model(data)
        loss = criterion_val(output, target)
        pred = torch.max(output.data, 1)[1]
        pred_list.extend(pred.tolist())
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        n_samples = target.size(0)
        loss_meter.update(loss.item(), n_samples)
        acc1_meter.update(acc1.item(), n_samples)
        acc5_meter.update(acc5.item(), n_samples)
    acc = acc1_meter.avg
    print('\nVal set: Average loss: {:.4f}\tAcc1:{:.3f}%\tAcc5:{:.3f}%\n'.format(
        loss_meter.avg, acc, acc5_meter.avg))

    # Save the wrapped module's weights when the model is a DataParallel wrapper.
    net = model.module if isinstance(model, torch.nn.DataParallel) else model

    if acc > Best_ACC:
        torch.save(net.state_dict(), f"{file_dir}/best.pth")
        Best_ACC = acc

    # Per-epoch checkpoint; with EMA enabled the same weights are stored a
    # second time under 'state_dict_ema'.
    state = {
        'epoch': epoch,
        'state_dict': net.state_dict(),
        'Best_ACC': Best_ACC,
    }
    if use_ema:
        state['state_dict_ema'] = net.state_dict()
    torch.save(state, f"{file_dir}/model_{epoch}_{round(acc, 3)}.pth")
    return val_list, pred_list, loss_meter.avg, acc
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random generators.

    Also switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random

    import numpy as np

    # NOTE: changing PYTHONHASHSEED at runtime does not re-seed hashing in the
    # current interpreter; it is only inherited by child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    # timm's Mixup draws its mixing parameters from NumPy's global generator.
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
    # Create the folder that receives checkpoints, logs and plots.
    file_dir = 'checkpoints/FastVit/'
    if os.path.exists(file_dir):
        print('true')
    os.makedirs(file_dir, exist_ok=True)

    # Global settings.
    model_lr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 300
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    use_amp = True  # use mixed precision
    use_dp = True  # use DataParallel multi-GPU training
    classes = 12
    resume = None
    CLIP_GRAD = 5.0
    Best_ACC = 0  # best validation accuracy so far
    use_ema = True
    model_ema_decay = 0.9998
    start_epoch = 1
    seed = 1
    seed_everything(seed)

    # Data preprocessing.
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
    transform_test = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
    mixup_fn = Mixup(
        mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
        prob=0.1, switch_prob=0.5, mode='batch',
        label_smoothing=0.1, num_classes=classes)

    # Read the datasets and persist the class-name -> index mapping.
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))

    # Data loaders.
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

    # Loss functions: soft targets for mixup training, hard labels for validation.
    criterion_train = SoftTargetCrossEntropy()
    criterion_val = torch.nn.CrossEntropyLoss()

    # Build the model and replace the classification head.
    model_ft = fastvit_t8(pretrained=True)
    print(model_ft)
    num_fr = model_ft.head.in_features
    model_ft.head = nn.Linear(num_fr, classes)
    if resume:
        # Load onto the CPU first so a checkpoint saved on another device
        # layout can still be restored; the model is moved to DEVICE below.
        checkpoint = torch.load(resume, map_location='cpu')
        print(checkpoint['state_dict'].keys())
        model_ft.load_state_dict(checkpoint['state_dict'])
        Best_ACC = checkpoint['Best_ACC']
        start_epoch = checkpoint['epoch'] + 1
    model_ft.to(DEVICE)
    print(model_ft)

    # AdamW optimizer with a low learning rate and cosine annealing.
    optimizer = optim.AdamW(model_ft.parameters(), lr=model_lr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-6)
    if use_amp:
        scaler = torch.cuda.amp.GradScaler()
    if torch.cuda.device_count() > 1 and use_dp:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model_ft = torch.nn.DataParallel(model_ft)
    if use_ema:
        model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device=DEVICE,
            resume=resume)
    else:
        model_ema = None

    # Training and validation.
    is_set_lr = False
    log_dir = {}
    train_loss_list, val_loss_list, train_acc_list, val_acc_list, epoch_list = [], [], [], [], []
    # Use one spelling of the log path for both reading and writing.
    result_path = os.path.join(file_dir, 'result.json')
    if resume and os.path.isfile(result_path):
        with open(result_path, 'r', encoding='utf-8') as file:
            logs = json.load(file)
            train_acc_list = logs['train_acc']
            train_loss_list = logs['train_loss']
            val_acc_list = logs['val_acc']
            val_loss_list = logs['val_loss']
            epoch_list = logs['epoch_list']
    for epoch in range(start_epoch, EPOCHS + 1):
        epoch_list.append(epoch)
        log_dir['epoch_list'] = epoch_list
        train_loss, train_acc = train(model_ft, DEVICE, train_loader, optimizer, epoch, model_ema)
        train_loss_list.append(train_loss)
        train_acc_list.append(train_acc)
        log_dir['train_acc'] = train_acc_list
        log_dir['train_loss'] = train_loss_list
        # Validate the EMA weights when EMA is enabled.
        if use_ema:
            val_list, pred_list, val_loss, val_acc = val(model_ema.ema, DEVICE, test_loader)
        else:
            val_list, pred_list, val_loss, val_acc = val(model_ft, DEVICE, test_loader)
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)
        log_dir['val_acc'] = val_acc_list
        log_dir['val_loss'] = val_loss_list
        log_dir['best_acc'] = Best_ACC
        with open(result_path, 'w', encoding='utf-8') as file:
            file.write(json.dumps(log_dir))
        # Pass explicit labels so the report still works when a class is
        # missing from this epoch's labels or predictions.
        class_names = dataset_train.classes
        print(classification_report(val_list, pred_list,
                                    labels=list(range(len(class_names))),
                                    target_names=class_names))
        # NOTE: with EPOCHS = 300 the else branch below is never reached.
        if epoch < 600:
            cosine_schedule.step()
        else:
            if not is_set_lr:
                for param_group in optimizer.param_groups:
                    param_group["lr"] = 1e-6
                is_set_lr = True

        # Refresh the loss curves.
        fig = plt.figure(1)
        plt.plot(epoch_list, train_loss_list, 'r-', label=u'Train Loss')
        plt.plot(epoch_list, val_loss_list, 'b-', label=u'Val Loss')
        plt.legend(["Train Loss", "Val Loss"], loc="upper right")
        plt.xlabel(u'epoch')
        plt.ylabel(u'loss')
        plt.title('Model Loss ')
        plt.savefig(file_dir + "/loss.png")
        plt.close(1)

        # Refresh the accuracy curves.
        fig2 = plt.figure(2)
        plt.plot(epoch_list, train_acc_list, 'r-', label=u'Train Acc')
        plt.plot(epoch_list, val_acc_list, 'b-', label=u'Val Acc')
        plt.legend(["Train Acc", "Val Acc"], loc="lower right")
        plt.title("Model Acc")
        plt.ylabel("acc")
        plt.xlabel("epoch")
        plt.savefig(file_dir + "/acc.png")
        plt.close(2)
- 执行export_model.py 导出模型
#export_model.py
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
import os
import coremltools
import torch
from torch import nn
import models
from models.modules.mobileone import reparameterize_model
def parse_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the export script's command-line options on ``parser``.

    Adds ``--variant``, ``--output-dir`` and ``--checkpoint``, all string
    options with defaults, and returns the same parser object.
    """
    option_specs = (
        ("--variant", 'fastvit_t8', "Provide fastvit model variant name."),
        ("--output-dir", ".", "Provide location to save exported models."),
        ("--checkpoint", 'checkpoints/FastVit/best.pth', "Provide location of trained checkpoint."),
    )
    for flag, default_value, help_text in option_specs:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    return parser
def export(variant: str, output_dir: str, checkpoint: str = None,
           num_classes: int = 12, image_size: int = 256) -> None:
    """Reparameterize a FastViT model and save it as a traced TorchScript file.

    The traced module is written to ``<output_dir>/<variant>.pt``.

    Args:
        variant: FastViT model variant, looked up by name on ``models``.
        output_dir: Path to save the exported model.
        checkpoint: Path to trained checkpoint. Default: ``None``
        num_classes: Output size of the replacement classification head.
        image_size: Height and width of the square dummy input used for tracing.
    """
    # Create output directory.
    os.makedirs(output_dir, exist_ok=True)

    # Random input tensor for tracing purposes.
    inputs = torch.rand(1, 3, image_size, image_size)

    # Instantiate model variant and swap in a head of the requested size.
    model = getattr(models, variant)()
    num_fr = model.head.in_features
    model.head = nn.Linear(num_fr, num_classes)

    print(f"Export and Convert Model: {variant}")
    print(model)
    if checkpoint is not None:
        print(f"Load checkpoint {checkpoint}")
        # Load onto the CPU so checkpoints saved on a GPU also load on
        # machines without one.
        chkpt = torch.load(checkpoint, map_location="cpu")
        print(chkpt.keys())
        # Accept a full training checkpoint as well as a bare state dict.
        if "state_dict" in chkpt:
            chkpt = chkpt["state_dict"]
        model.load_state_dict(chkpt)

    # Always reparameterize before exporting.
    reparameterized_model = reparameterize_model(model)
    reparameterized_model.eval()

    # Trace and export.
    traced_model = torch.jit.trace(reparameterized_model, inputs)
    output_path = os.path.join(output_dir, variant)
    pt_name = output_path + ".pt"
    traced_model.save(pt_name)
if __name__ == "__main__":
    # Build the CLI, read the options and run the export.
    cli = parse_args(
        argparse.ArgumentParser(description="Script to export coreml package file")
    )
    options = cli.parse_args()
    export(
        variant=options.variant,
        output_dir=options.output_dir,
        checkpoint=options.checkpoint,
    )
- 执行test.py进行测试
# test.py
import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os
# Class names, indexed by the model's output index.
classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
           'Common wheat', 'Fat Hen', 'Loose Silky-bent',
           'Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet')

# Resize, tensor conversion and normalisation applied to every test image.
transform_test = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
])

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the traced model straight onto the target device so a model exported
# on a GPU machine can also be loaded on a CPU-only one.
chkpoint = torch.jit.load('fastvit_t8.pt', map_location=DEVICE)
chkpoint.eval()
chkpoint.to(DEVICE)

path = 'test/'
# Sort so the output order is the same on every run.
testList = sorted(os.listdir(path))

# Inference only: no gradients are needed.
with torch.no_grad():
    for file in testList:
        # Force three channels; grayscale or RGBA files would otherwise not
        # match the 3-channel normalisation.
        img = Image.open(path + file).convert('RGB')
        img = transform_test(img).unsqueeze(0).to(DEVICE)
        out = chkpoint(img)
        # Predict
        _, pred = torch.max(out, 1)
        print('Image Name:{},predict:{}'.format(file, classes[pred.item()]))