本文已参与 [新人创作礼] 活动,一起开启掘金创作之路。
data.py源码如下:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import os
import cv2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def crop_top(img, percent=0.15):
offset = int(img.shape[0] * percent)
return img[offset:]
def central_crop(img):
size = min(img.shape[0], img.shape[1])
offset_h = int((img.shape[0] - size) / 2)
offset_w = int((img.shape[1] - size) / 2)
return img[offset_h:offset_h + size, offset_w:offset_w + size]
def process_image_file(filepath, top_percent, size):
img = cv2.imread(filepath)
img = crop_top(img, percent=top_percent)
img = central_crop(img)
img = cv2.resize(img, (size, size))
return img
def random_ratio_resize(img, prob=0.3, delta=0.1):
if np.random.rand() >= prob:
return img
ratio = img.shape[0] / img.shape[1]
ratio = np.random.uniform(max(ratio - delta, 0.01), ratio + delta)
if ratio * img.shape[1] <= img.shape[1]:
size = (int(img.shape[1] * ratio), img.shape[1])
else:
size = (img.shape[0], int(img.shape[0] / ratio))
dh = img.shape[0] - size[1]
top, bot = dh // 2, dh - dh // 2
dw = img.shape[1] - size[0]
left, right = dw // 2, dw - dw // 2
if size[0] > 480 or size[1] > 480:
print(img.shape, size, ratio)
img = cv2.resize(img, size)
img = cv2.copyMakeBorder(img, top, bot, left, right, cv2.BORDER_CONSTANT,
(0, 0, 0))
if img.shape[0] != 480 or img.shape[1] != 480:
raise ValueError(img.shape, size)
return img
_augmentation_transform = ImageDataGenerator(
featurewise_center=False,
featurewise_std_normalization=False,
rotation_range=10,
width_shift_range=0.1,
height_shift_range=0.1,
horizontal_flip=True,
brightness_range=(0.9, 1.1),
zoom_range=(0.85, 1.15),
fill_mode='constant',
cval=0.,
)
def apply_augmentation(img):
img = random_ratio_resize(img)
img = _augmentation_transform.random_transform(img)
return img
def _process_csv_file(file):
with open(file, 'r') as fr:
files = fr.readlines()
return files
class BalanceCovidDataset(keras.utils.Sequence):
'Generates data for Keras'
def __init__(
self,
data_dir,
csv_file,
is_training=True,
batch_size=8,
input_shape=(224, 224),
n_classes=3,
num_channels=3,
mapping={
'normal': 0,
'pneumonia': 1,
'COVID-19': 2
},
shuffle=True,
augmentation=apply_augmentation,
covid_percent=0.3,
class_weights=[1., 1., 6.],
top_percent=0.08
):
'Initialization'
self.datadir = data_dir
self.dataset = _process_csv_file(csv_file)
self.is_training = is_training
self.batch_size = batch_size
self.N = len(self.dataset)
self.input_shape = input_shape
self.n_classes = n_classes
self.num_channels = num_channels
self.mapping = mapping
self.shuffle = True
self.covid_percent = covid_percent
self.class_weights = class_weights
self.n = 0
self.augmentation = augmentation
self.top_percent = top_percent
datasets = {'normal': [], 'pneumonia': [], 'COVID-19': []}
for l in self.dataset:
datasets[l.split()[2]].append(l)
self.datasets = [
datasets['normal'] + datasets['pneumonia'],
datasets['COVID-19'],
]
print(len(self.datasets[0]), len(self.datasets[1]))
self.on_epoch_end()
def __next__(self):
# Get one batch of data
batch_x, batch_y, weights = self.__getitem__(self.n)
# Batch index
self.n += 1
# If we have processed the entire dataset then
if self.n >= self.__len__():
self.on_epoch_end
self.n = 0
return batch_x, batch_y, weights
def __len__(self):
return int(np.ceil(len(self.datasets[0]) / float(self.batch_size)))
def on_epoch_end(self):
'Updates indexes after each epoch'
if self.shuffle == True:
for v in self.datasets:
np.random.shuffle(v)
def __getitem__(self, idx):
batch_x, batch_y = np.zeros(
(self.batch_size, *self.input_shape,
self.num_channels)), np.zeros(self.batch_size)
batch_files = self.datasets[0][idx * self.batch_size:(idx + 1) *
self.batch_size]
# upsample covid cases
covid_size = max(int(len(batch_files) * self.covid_percent), 1)
covid_inds = np.random.choice(np.arange(len(batch_files)),
size=covid_size,
replace=False)
covid_files = np.random.choice(self.datasets[1],
size=covid_size,
replace=False)
for i in range(covid_size):
batch_files[covid_inds[i]] = covid_files[i]
for i in range(len(batch_files)):
sample = batch_files[i].split()
if self.is_training:
folder = 'train'
else:
folder = 'test'
x = process_image_file(os.path.join(self.datadir, folder, sample[1]),
self.top_percent,
self.input_shape[0])
if self.is_training and hasattr(self, 'augmentation'):
x = self.augmentation(x)
x = x.astype('float32') / 255.0
y = self.mapping[sample[2]]
batch_x[i] = x
batch_y[i] = y
class_weights = self.class_weights
weights = np.take(class_weights, batch_y.astype('int64'))
return batch_x, keras.utils.to_categorical(batch_y, num_classes=self.n_classes), weights
import tensorflow as tf
from tensorflow import keras
import numpy as np
import os
import cv2
导入各个库。
from tensorflow.keras.preprocessing.image import ImageDataGenerator
从tensorflow.keras.preprocessing.image导入ImageDataGenerator函数
以下内容参考:
keras的图像预处理全攻略(二)—— ImageDataGenerator 类_LoveMIss-Y的博客-CSDN博客_imagedatagenerator
【Tool】Keras 基础学习 III ImageDataGenerator() - 简书
ImageDataGenerator类的定义以及构造函数的参数详解
1. ImageDataGenerator类简介
ImageDataGenerator类在xxx\Lib\site-packages\keras_preprocessing\image.py中,定义如下:
class ImageDataGenerator(object):
"""Generate batches of tensor image data with real-time data augmentation.
The data will be looped over (in batches).
"""
这个类是做什么用的?ImageDataGenerator()是keras.preprocessing.image模块中的图片生成器,同时也可以在batch中实时对数据进行增强,扩充数据集大小,增强模型的泛化能力。比如进行旋转,变形,归一化等等,并且可以循环迭代。在Keras中,当数据量很多的时候需要使用model.fit_generator()方法,该方法接受的第一个参数就是一个生成器。简单来说就是:ImageDataGenerator()是keras.preprocessing.image模块中的图片生成器,可以每一次给模型“喂”一个batch_size大小的样本数据,同时也可以在每一个批次中对这batch_size个样本数据进行增强,扩充数据集大小,增强模型的泛化能力。比如进行旋转,变形,归一化等等。
总结起来就是以下两点:
(1)图片生成器,负责生成一个批次一个批次的图片,以生成器的形式给模型训练;
(2)对每一个批次的训练图片,适时地进行数据增强处理(data augmentation)。
2. 数据增强处理(data augmentation)
数据增强有两种方法。一种是事先执行所有转换,实质上会增强你的数据集的大小。另一种选项是在送入机器学习之前,在小批量(mini-batch)上执行这些转换。第一个选项叫做线下增强(offline augmentation)。这种方法适用于较小的数据集(smaller dataset)。你最终会增加一定的倍数的数据集,这个倍数等于你转换的个数。第二种方法叫做线上增强(online augmentation)或在飞行中增强(augmentation on the fly)。这种方法更适用于较大的数据集(larger datasets),因为你无法承受爆炸性增加的规模。
数据增强的手段有非常多种,这里指说一些代表性的。
- 旋转 | 反射变换(Rotation/reflection): 随机旋转图像一定角度; 改变图像内容的朝向。
- 翻转变换(flip): 沿着水平或者垂直方向翻转图像。
- 缩放变换(zoom): 按照一定的比例放大或者缩小图像。
- 平移变换(shift): 在图像平面上对图像以一定方式进行平移;可以采用随机或人为定义的方式指定平移范围和平移步长, 沿水平或竖直方向进行平移. 改变图像内容的位置。
- 尺度变换(scale): 对图像按照指定的尺度因子, 进行放大或缩小; 或者参照SIFT特征提取思想, 利用指定的尺度因子对图像滤波构造尺度空间. 改变图像内容的大小或模糊程度。
- 对比度变换(contrast): 在图像的HSV颜色空间,改变饱和度S和V亮度分量,保持色调H不变. 对每个像素的S和V分量进行指数运算(指数因子在0.25到4之间), 增加光照变化。
- 噪声扰动(noise): 对图像的每个像素RGB进行随机扰动, 常用的噪声模式是椒盐噪声和高斯噪声。
- 错切变换(shear):效果就是让所有点的x坐标(或者y坐标)保持不变,而对应的y坐标(或者x坐标)则按比例发生平移,且平移的大小和该点到x轴(或y轴)的垂直距离成正比。
3. ImageDataGenerator类的构造函数参数
keras.preprocessing.image.ImageDataGenerator(featurewise_center=False, samplewise_center=False, featurewise_std_normalization=False, samplewise_std_normalization=False, zca_whitening=False, zca_epsilon=1e-06, rotation_range=0.0, width_shift_range=0.0, height_shift_range=0.0, brightness_range=None, shear_range=0.0, zoom_range=0.0, channel_shift_range=0.0, fill_mode='nearest', cval=0.0, horizontal_flip=False, vertical_flip=False, rescale=None, preprocessing_function=None, data_format=None, validation_split=0.0)
- featurewise_center: 布尔值。将输入数据的均值设置为 0,逐特征进行,对输入的图片每个通道减去每个通道对应均值。
- samplewise_center: 布尔值。将每个样本的均值设置为 0,每张图片减去样本均值, 使得每个样本均值为0。
- featurewise_std_normalization: Boolean. 布尔值。将每个输入(即每张图片)除以数据集(dataset)标准差,逐特征进行。
- samplewise_std_normalization: 布尔值。将每个输入(即每张图片)除以其自身(图片本身)的标准差。
这里需要注意两个概念,所谓 featurewise指的是逐特征,它针对的是数据集dataset,而samplewise针对的是单个输入图片的本身。featurewise是从整个数据集的分布去考虑的,而samplewise只是针对自身图片 。
- zca_epsilon: ZCA 白化的 epsilon 值,默认为 1e-6。
- zca_whitening: 布尔值。是否应用 ZCA 白化。
- rotation_range: 整数。随机旋转的度数范围。
- width_shift_range: 它的值可以是浮点数、一维数组、整数。
- float: 如果 <1,则是除以总宽度的值,或者如果 >=1,则为像素值。
- 1-D 数组: 数组中的随机元素。
- int: 来自间隔
(-width_shift_range, +width_shift_range)之间的整数个像素。 width_shift_range=2时,可能值是整数[-1, 0, +1],与width_shift_range=[-1, 0, +1]相同;而width_shift_range=1.0时,可能值是[-1.0, +1.0)之间的浮点数。
- height_shift_range: 浮点数、一维数组或整数(同width_shift_range)。
- float: 如果 <1,则是除以总宽度的值,或者如果 >=1,则为像素值。
- 1-D array-like: 数组中的随机元素。
- int: 来自间隔
(-height_shift_range, +height_shift_range)之间的整数个像素。 height_shift_range=2时,可能值是整数[-1, 0, +1],与height_shift_range=[-1, 0, +1]相同;而height_shift_range=1.0时,可能值是[-1.0, +1.0)之间的浮点数。
- shear_range: 浮点数。剪切强度(以弧度逆时针方向剪切角度)。所谓shear_range就是错切变换,效果就是让所有点的x坐标(或者y坐标)保持不变,而对应的y坐标(或者x坐标)则按比例发生平移,且平移的大小和该点到x轴(或y轴)的垂直距离成正比。
- brightness_range: 两个浮点数组成的元组或者是列表,像素的亮度会在这个范围之类随机确定。
- zoom_range: 浮点数 或
[lower, upper]。随机缩放范围。如果是浮点数,[lower, upper] = [1-zoom_range, 1+zoom_range]。zoom_range参数可以让图片在长或宽的方向进行放大,可以理解为某方向的resize,因此这个参数可以是一个数或者是一个list。当给出一个数时,图片同时在长宽两个方向进行同等程度的放缩操作;当给出一个list时,则代表[width_zoom_range, height_zoom_range],即分别对长宽进行不同程度的放缩。而参数大于0小于1时,执行的是放大操作,当参数大于1时,执