本文已参与「新人创作礼」活动,一起开启掘金创作之路。
原始链接:github.com/jfzhang95/pytorch-deeplab-xception
1.数据准备
images文件夹和labels文件夹内的图像和标签名是一一对应的,名字是一样的,标签的具体内容应该是0,1,2,3这样代表类别的数据。文件夹名字最好和我的都一样,因为代码里有的地方写了文件名。
2.数据相关代码 (1)数据读入,创建对应的脚本放进对应的位置 dataloaders/datasets/own_data.py
import os
import cv2
import random
import numpy as np
from shutil import copyfile, move
from osgeo import gdal
from PIL import Image
import torch
import torch.utils.data as data
from torchvision import transforms
from dataloaders import custom_transforms as tr
def read_img(filename):
    """Open a raster with GDAL and return (projection, geotransform, width, height, pixel array)."""
    ds = gdal.Open(filename)
    width = ds.RasterXSize
    height = ds.RasterYSize
    geotrans = ds.GetGeoTransform()
    proj = ds.GetProjection()
    # Read all bands over the full extent into a numpy array.
    data = ds.ReadAsArray(0, 0, width, height)
    del ds  # release the GDAL dataset handle
    return proj, geotrans, width, height, data
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a numpy array to a georeferenced GeoTIFF.

    Args:
        filename: output path of the GeoTIFF.
        im_proj: projection WKT string (as returned by read_img).
        im_geotrans: 6-element geotransform tuple.
        im_data: array of shape (bands, H, W) or (H, W).
    """
    # Map the numpy dtype onto the closest GDAL pixel type.
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    # (bands, H, W) for multi-band input, single band for a 2-D array.
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Fix: explicitly flush and close the dataset so all pixel data is
    # guaranteed to reach disk (GDAL only writes on flush/close).
    dataset.FlushCache()
    del dataset
def gdal_loader(img_path, mask_path):
    """Read an (image, mask) pair via GDAL; keep the first 3 image bands scaled to [0, 1]."""
    *_, image = read_img(img_path)
    *_, mask = read_img(mask_path)
    # The source imagery has 4 bands; only the first three are used for
    # training (supporting the 4th would require changes in several places).
    image = image[0:3] / 255.0
    return image, mask
def read_own_data(root_path, split='train'):
    """Collect matching (image, label) file paths for a dataset split.

    Expects <root_path>/<split>/images and <root_path>/<split>/labels to
    contain identically named files.

    Args:
        root_path: dataset root directory.
        split: subdirectory name, e.g. 'train' or 'val'.

    Returns:
        (images, masks): two parallel lists of absolute/relative paths.
    """
    # Fix: build paths with os.path.join components instead of a hard-coded
    # '/' separator, and sort so sample order is deterministic across
    # file systems (os.listdir order is arbitrary).
    image_root = os.path.join(root_path, split, 'images')
    gt_root = os.path.join(root_path, split, 'labels')
    names = sorted(os.listdir(image_root))
    images = [os.path.join(image_root, n) for n in names]
    masks = [os.path.join(gt_root, n) for n in names]
    return images, masks
def own_data_loader(img_path, mask_path):
    """Load an RGB image and its single-channel mask as float32 CHW arrays.

    Args:
        img_path: path to an image readable by cv2 (BGR, HWC).
        mask_path: path to the corresponding label image.

    Returns:
        (img, mask): float32 arrays in CHW layout, both scaled by 1/255.
    """
    img = cv2.imread(img_path)
    img = img / 255.0
    # Fix: the mask was never read — the Image.open line had been commented
    # out, leaving `mask` undefined (NameError at the expand_dims below).
    mask = np.array(Image.open(mask_path))
    mask = np.expand_dims(mask, axis=2)
    mask = mask / 255.0
    # HWC -> CHW for PyTorch.
    img = np.array(img, np.float32).transpose(2, 0, 1)
    mask = np.array(mask, np.float32).transpose(2, 0, 1)
    return img, mask
class ImageFolder(data.Dataset):
    """Dataset for the user's own remote-sensing segmentation data.

    Expects <root>/<split>/images and <root>/<split>/labels with
    identically named image/label files; samples are read through GDAL.
    """
    NUM_CLASSES = 4

    def __init__(self, args, split='train'):
        self.args = args
        self.root = self.args.root_path
        self.split = split
        self.images, self.labels = read_own_data(self.root, self.split)

    def transform_tr(self, sample):
        """Training transform: normalize to [-1, 1] and convert to tensors."""
        pipeline = transforms.Compose([
            tr.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            tr.ToTensor(),
        ])
        return pipeline(sample)

    def transform_val(self, sample):
        """Validation transform: same normalization, no augmentation."""
        pipeline = transforms.Compose([
            tr.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            tr.ToTensor(),
        ])
        return pipeline(sample)

    def __getitem__(self, index):
        # gdal_loader keeps the first three bands scaled to [0, 1]; the
        # transform_* pipelines above are currently bypassed.
        image, mask = gdal_loader(self.images[index], self.labels[index])
        return {'image': image, 'label': mask}

    def __len__(self):
        assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
        return len(self.images)
(2)修改 dataloaders/__init__.py 文件,添加自己数据的调用入口
from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd, own_data
from torch.utils.data import DataLoader
def make_data_loader(args, **kwargs):
    """Build the data loaders for the dataset named in ``args.dataset``.

    Args:
        args: namespace with at least .dataset, .batch_size (plus dataset
            specific options such as .use_sbd or .root_path).
        **kwargs: forwarded to every DataLoader (num_workers, pin_memory, ...).

    Returns:
        (train_loader, val_loader, test_loader, num_class); test_loader is
        None for datasets without a test split.

    Raises:
        NotImplementedError: for unknown dataset names.
    """
    if args.dataset == 'pascal':
        train_set = pascal.VOCSegmentation(args, split='train')
        val_set = pascal.VOCSegmentation(args, split='val')
        if args.use_sbd:
            sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
            train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set])
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class
    elif args.dataset == 'cityscapes':
        train_set = cityscapes.CityscapesSegmentation(args, split='train')
        val_set = cityscapes.CityscapesSegmentation(args, split='val')
        test_set = cityscapes.CityscapesSegmentation(args, split='test')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        return train_loader, val_loader, test_loader, num_class
    elif args.dataset == 'coco':
        train_set = coco.COCOSegmentation(args, split='train')
        val_set = coco.COCOSegmentation(args, split='val')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class
    elif args.dataset == 'own':
        train_set = own_data.ImageFolder(args, split='train')
        val_set = own_data.ImageFolder(args, split='val')
        num_class = train_set.NUM_CLASSES
        # Fix: forward **kwargs (num_workers / pin_memory) like every other
        # branch, and do not shuffle the validation set — shuffle=True plus
        # drop_last=True made validation evaluate a different random subset
        # each epoch, which is inconsistent with the other datasets.
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=True, **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class
    else:
        raise NotImplementedError
3.训练代码修改,要改的地方不多,因为数据的导入已经改到适应训练了,下面有两个地方要改,给了中文注释的,认真看下 train.py,训练的时候会自动创建run文件夹放模型
import argparse
import os
import numpy as np
from tqdm import tqdm
from mypath import Path
from dataloaders import make_data_loader
from modeling.sync_batchnorm.replicate import patch_replication_callback
from modeling.deeplab import *
from utils.loss import SegmentationLosses
from utils.calculate_weights import calculate_weigths_labels
from utils.lr_scheduler import LR_Scheduler
from utils.saver import Saver
from utils.summaries import TensorboardSummary
from utils.metrics import Evaluator
class Trainer(object):
    """Training driver for DeepLabV3+ on the custom dataset.

    Builds the dataloaders, network, optimizer, criterion, LR scheduler and
    evaluator from ``args``; optionally resumes from a checkpoint; exposes
    ``training(epoch)`` and ``validation(epoch)``.
    """

    def __init__(self, args):
        self.args = args
        # Define Saver (creates the run/ experiment directory, stores config)
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
        # Backbone params train at the base lr, decoder params at 10x lr.
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)
        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                # Compute class weights from the training labels (cached to .npy).
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))
        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                # Under DataParallel the real model lives in .module.
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        """Run one training epoch; logs per-iteration and per-epoch loss."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda().float(), target.cuda()
            # The LR schedule (poly/step/cos) is stepped per iteration.
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)
            # Show 10 * 3 inference results each epoch
            # if i % (num_img_tr // 10) == 0:
            #     global_step = i + num_img_tr * epoch
            #     self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)  # disabled: the stock visualization targets the 21-class VOC palette and does not apply here
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)
        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        """Run one validation pass; computes Acc / mIoU / FWIoU and saves the best model."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda().float(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Per-pixel class = channel with the highest score.
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        # mIoU is the model-selection metric.
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
def main():
    """Parse CLI arguments, fill dataset-dependent defaults and run training/validation."""
    parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--backbone', type=str, default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')
    parser.add_argument('--out-stride', type=int, default=16,
                        help='network output stride (default: 16)')
    parser.add_argument('--dataset', type=str, default='pascal',
                        choices=['pascal', 'coco', 'cityscapes', 'own'],
                        help='dataset name (default: pascal)')
    # Root directory of the custom ("own") dataset.
    # Fix: the original line ended in stray markdown (**#...**), a SyntaxError.
    parser.add_argument('--root_path', type=str, default='./data/',
                        help='the own path root')
    parser.add_argument('--use-sbd', action='store_true', default=True,
                        help='whether to use SBD dataset (default: True)')
    parser.add_argument('--workers', type=int, default=4,
                        metavar='N', help='dataloader threads')
    parser.add_argument('--base-size', type=int, default=512,
                        help='base image size')
    parser.add_argument('--crop-size', type=int, default=448,
                        help='crop image size')
    parser.add_argument('--sync-bn', type=bool, default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument('--freeze-bn', type=bool, default=False,
                        help='whether to freeze bn parameters (default: False)')
    parser.add_argument('--loss-type', type=str, default='ce',
                        choices=['ce', 'focal'],
                        help='loss func type (default: ce)')
    # training hyper params
    parser.add_argument('--epochs', type=int, default=None, metavar='N',
                        help='number of epochs to train (default: auto)')
    parser.add_argument('--start_epoch', type=int, default=0,
                        metavar='N', help='start epochs (default:0)')
    parser.add_argument('--batch-size', type=int, default=None,
                        metavar='N', help='input batch size for \
                                training (default: auto)')
    parser.add_argument('--test-batch-size', type=int, default=None,
                        metavar='N', help='input batch size for \
                                testing (default: auto)')
    parser.add_argument('--use-balanced-weights', action='store_true', default=False,
                        help='whether to use balanced weights (default: False)')
    # optimizer params
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (default: auto)')
    parser.add_argument('--lr-scheduler', type=str, default='poly',
                        choices=['poly', 'step', 'cos'],
                        help='lr scheduler mode: (default: poly)')
    parser.add_argument('--momentum', type=float, default=0.9,
                        metavar='M', help='momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        metavar='M', help='w-decay (default: 5e-4)')
    parser.add_argument('--nesterov', action='store_true', default=False,
                        help='whether use nesterov (default: False)')
    # cuda, seed and logging
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--gpu-ids', type=str, default='0',
                        help='use which gpu to train, must be a \
                        comma-separated list of integers only (default=0)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    # checking point
    parser.add_argument('--resume', type=str, default=None,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkname', type=str, default=None,
                        help='set the checkpoint name')
    # finetuning pre-trained models
    parser.add_argument('--ft', action='store_true', default=False,
                        help='finetuning on a different dataset')
    # evaluation option
    parser.add_argument('--eval-interval', type=int, default=1,
                        help='evaluation interval (default: 1)')
    parser.add_argument('--no-val', action='store_true', default=False,
                        help='skip validation during training')
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        try:
            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
        except ValueError:
            raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
    # Sync BN only makes sense on multiple GPUs.
    if args.sync_bn is None:
        if args.cuda and len(args.gpu_ids) > 1:
            args.sync_bn = True
        else:
            args.sync_bn = False
    # default settings for epochs, batch_size and lr
    if args.epochs is None:
        epoches = {
            'coco': 30,
            'cityscapes': 200,
            'pascal': 50,
            # Fix: '--dataset own' raised KeyError here when --epochs
            # was not given on the command line.
            'own': 50,
        }
        args.epochs = epoches[args.dataset.lower()]
    if args.batch_size is None:
        args.batch_size = 4 * len(args.gpu_ids)
    if args.test_batch_size is None:
        args.test_batch_size = args.batch_size
    if args.lr is None:
        lrs = {
            'coco': 0.1,
            'cityscapes': 0.01,
            'pascal': 0.007,
            # Fix: default lr for the custom dataset (the example command
            # passes --lr 0.002 explicitly; this only guards the default path).
            'own': 0.007,
        }
        args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size
    if args.checkname is None:
        args.checkname = 'deeplab-' + str(args.backbone)
    print(args)
    torch.manual_seed(args.seed)
    trainer = Trainer(args)
    print('Starting Epoch:', trainer.args.start_epoch)
    print('Total Epoches:', trainer.args.epochs)
    for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
        trainer.training(epoch)
        # Validate every eval_interval epochs (on the last epoch of each interval).
        if not trainer.args.no_val and epoch % args.eval_interval == (args.eval_interval - 1):
            trainer.validation(epoch)
    trainer.writer.close()


if __name__ == "__main__":
    main()
4.预测代码,原始的代码没有预测,这里给出一下 inference.py
import argparse
import os
import numpy as np
import tqdm
import torch
import json
from osgeo import gdal
from easydict import EasyDict
from PIL import Image
from modeling.deeplab import *
# from utils.dataloader import seg_utils, make_data_loader
from utils.metrics import Evaluator
from torchvision import transforms
def read_img(filename):
    """Read a raster via GDAL; return (projection, geotransform, width, height, array)."""
    handle = gdal.Open(filename)
    cols, rows = handle.RasterXSize, handle.RasterYSize
    transform = handle.GetGeoTransform()
    projection = handle.GetProjection()
    pixels = handle.ReadAsArray(0, 0, cols, rows)
    del handle  # close the dataset
    return projection, transform, cols, rows, pixels
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a numpy array to a georeferenced GeoTIFF.

    Args:
        filename: output GeoTIFF path.
        im_proj: projection WKT string.
        im_geotrans: 6-element geotransform tuple.
        im_data: array of shape (bands, H, W) or (H, W).
    """
    # Map the numpy dtype onto the closest GDAL pixel type.
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    # (bands, H, W) for multi-band input, single band for a 2-D array.
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Fix: flush and close the dataset so pixel data is guaranteed to reach
    # disk (GDAL defers writes until flush/close).
    dataset.FlushCache()
    del dataset
class Tester(object):
    """Runs per-image inference with a trained DeepLab checkpoint (CPU only)."""

    def __init__(self, args):
        """Build the model described by ``args`` and load its checkpoint weights.

        Raises:
            RuntimeError: if args.model is not an existing checkpoint file.
        """
        if not os.path.isfile(args.model):
            raise RuntimeError("no checkpoint found at '{}'".format(args.model))
        self.args = args
        self.nclass = args.num_class
        # Define model (sync/freeze BN are irrelevant at inference time)
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=False,
                        freeze_bn=False)
        self.model = model
        # Load the checkpoint onto the CPU so no GPU is required.
        device = torch.device('cpu')
        checkpoint = torch.load(args.model, map_location=device)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.evaluator = Evaluator(self.nclass)

    def inference(self):
        """Predict a class map for every raster in args.data_path and write
        georeferenced uint8 GeoTIFFs to args.save_path."""
        self.model.eval()
        self.evaluator.reset()
        data_dir = self.args.data_path
        file_names = os.listdir(data_dir)
        for idx, test_file in enumerate(file_names):
            if test_file == '.DS_Store':  # skip macOS metadata files
                continue
            im_proj, im_geotrans, im_width, im_height, test_img = read_img(os.path.join(data_dir, test_file))
            # Fix: use splitext instead of slicing off a hard-coded 4 chars
            # (which assumed a 3-letter extension); the original line also
            # contained a redundant duplicated chained assignment.
            image_id, extension = os.path.splitext(test_file)
            # Normalize: keep the first three bands, scale to [0, 1], add batch dim.
            test_img = test_img[0:3]
            test_img = test_img / 255.0
            test_img = np.expand_dims(test_img, axis=0)
            # Fix: convert directly to a float32 tensor; torch.tensor(tensor, ...)
            # is a deprecated copy pattern that emits a warning.
            test_crop_tensor = torch.from_numpy(test_img).float()
            with torch.no_grad():
                output = self.model(test_crop_tensor)
            pred = output.data.cpu().numpy()
            # Per-pixel class = channel with the highest score.
            pred = np.argmax(pred, axis=1)
            inference_imgs = pred[0][:, :]
            print('inference ... {}/{}'.format(idx + 1, len(file_names)))
            # Write the class map preserving the input's georeferencing.
            write_img(os.path.join(self.args.save_path, image_id + extension),
                      im_proj, im_geotrans, inference_imgs.astype('uint8'))
def main():
    """Entry point: read config.json next to this script and run inference if enabled."""
    with open('config.json') as cfg_file:
        cfg = json.load(cfg_file)
    # Only the "inference" section of the config is used here.
    args = EasyDict(cfg['inference'])
    tester = Tester(args)
    if args.inference:
        print('predict...')
        tester.inference()


if __name__ == "__main__":
    main()
预测代码对应的json文件,放在同级目录 config.json。注意:标准 JSON 不允许注释,其中 crop_size 参数需和 train.py 的参数部分保持一致。
{
    "inference": {
        "inference": 1,
        "num_class": 4,
        "backbone": "resnet",
        "out_stride": 16,
        "crop_size": 448,
        "data_path": "./deeplab-xception/data/val/images",
        "save_path": "./deeplab-xception/result",
        "model": "./deeplab-xception/run/own/deeplab-resnet/model_best.pth.tar"
    }
}
训练使用的命令行命令:
python train.py --backbone resnet --lr 0.002 --workers 0 --epochs 50 --batch-size 3 --gpu-ids 0 --checkname deeplab-resnet --eval-interval 1 --dataset own
预测的时候改好json文件直接命令行运行inference.py就行了
因为一开始是四波段,用了gdal读取数据,没办法对应做数据增强的一些操作,感兴趣的可以把数据转为png直接用三波段参考pascal.py读入数据,用源码提供的数据增强试试