COVID-Net工程源码详解(五) - data.py解析

206 阅读7分钟

本文已参与 [新人创作礼] 活动,一起开启掘金创作之路。​

data.py源码如下:

import tensorflow as tf
from tensorflow import keras

import numpy as np
import os
import cv2

from tensorflow.keras.preprocessing.image import ImageDataGenerator

def crop_top(img, percent=0.15):
    offset = int(img.shape[0] * percent)
    return img[offset:]

def central_crop(img):
    size = min(img.shape[0], img.shape[1])
    offset_h = int((img.shape[0] - size) / 2)
    offset_w = int((img.shape[1] - size) / 2)
    return img[offset_h:offset_h + size, offset_w:offset_w + size]

def process_image_file(filepath, top_percent, size):
    img = cv2.imread(filepath)
    img = crop_top(img, percent=top_percent)
    img = central_crop(img)
    img = cv2.resize(img, (size, size))
    return img

def random_ratio_resize(img, prob=0.3, delta=0.1):
    if np.random.rand() >= prob:
        return img
    ratio = img.shape[0] / img.shape[1]
    ratio = np.random.uniform(max(ratio - delta, 0.01), ratio + delta)

    if ratio * img.shape[1] <= img.shape[1]:
        size = (int(img.shape[1] * ratio), img.shape[1])
    else:
        size = (img.shape[0], int(img.shape[0] / ratio))

    dh = img.shape[0] - size[1]
    top, bot = dh // 2, dh - dh // 2
    dw = img.shape[1] - size[0]
    left, right = dw // 2, dw - dw // 2

    if size[0] > 480 or size[1] > 480:
        print(img.shape, size, ratio)

    img = cv2.resize(img, size)
    img = cv2.copyMakeBorder(img, top, bot, left, right, cv2.BORDER_CONSTANT,
                             (0, 0, 0))

    if img.shape[0] != 480 or img.shape[1] != 480:
        raise ValueError(img.shape, size)
    return img

_augmentation_transform = ImageDataGenerator(
    featurewise_center=False,
    featurewise_std_normalization=False,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    brightness_range=(0.9, 1.1),
    zoom_range=(0.85, 1.15),
    fill_mode='constant',
    cval=0.,
)

def apply_augmentation(img):
    img = random_ratio_resize(img)
    img = _augmentation_transform.random_transform(img)
    return img

def _process_csv_file(file):
    with open(file, 'r') as fr:
        files = fr.readlines()
    return files


class BalanceCovidDataset(keras.utils.Sequence):
    'Generates data for Keras'

    def __init__(
            self,
            data_dir,
            csv_file,
            is_training=True,
            batch_size=8,
            input_shape=(224, 224),
            n_classes=3,
            num_channels=3,
            mapping={
                'normal': 0,
                'pneumonia': 1,
                'COVID-19': 2
            },
            shuffle=True,
            augmentation=apply_augmentation,
            covid_percent=0.3,
            class_weights=[1., 1., 6.],
            top_percent=0.08
    ):
        'Initialization'
        self.datadir = data_dir
        self.dataset = _process_csv_file(csv_file)
        self.is_training = is_training
        self.batch_size = batch_size
        self.N = len(self.dataset)
        self.input_shape = input_shape
        self.n_classes = n_classes
        self.num_channels = num_channels
        self.mapping = mapping
        self.shuffle = True
        self.covid_percent = covid_percent
        self.class_weights = class_weights
        self.n = 0
        self.augmentation = augmentation
        self.top_percent = top_percent

        datasets = {'normal': [], 'pneumonia': [], 'COVID-19': []}
        for l in self.dataset:
            datasets[l.split()[2]].append(l)
        self.datasets = [
            datasets['normal'] + datasets['pneumonia'],
            datasets['COVID-19'],
        ]
        print(len(self.datasets[0]), len(self.datasets[1]))

        self.on_epoch_end()

    def __next__(self):
        # Get one batch of data
        batch_x, batch_y, weights = self.__getitem__(self.n)
        # Batch index
        self.n += 1

        # If we have processed the entire dataset then
        if self.n >= self.__len__():
            self.on_epoch_end
            self.n = 0

        return batch_x, batch_y, weights

    def __len__(self):
        return int(np.ceil(len(self.datasets[0]) / float(self.batch_size)))

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        if self.shuffle == True:
            for v in self.datasets:
                np.random.shuffle(v)

    def __getitem__(self, idx):
        batch_x, batch_y = np.zeros(
            (self.batch_size, *self.input_shape,
             self.num_channels)), np.zeros(self.batch_size)

        batch_files = self.datasets[0][idx * self.batch_size:(idx + 1) *
                                       self.batch_size]

        # upsample covid cases
        covid_size = max(int(len(batch_files) * self.covid_percent), 1)
        covid_inds = np.random.choice(np.arange(len(batch_files)),
                                      size=covid_size,
                                      replace=False)
        covid_files = np.random.choice(self.datasets[1],
                                       size=covid_size,
                                       replace=False)
        for i in range(covid_size):
            batch_files[covid_inds[i]] = covid_files[i]

        for i in range(len(batch_files)):
            sample = batch_files[i].split()

            if self.is_training:
                folder = 'train'
            else:
                folder = 'test'

            x = process_image_file(os.path.join(self.datadir, folder, sample[1]),
                                   self.top_percent,
                                   self.input_shape[0])

            if self.is_training and hasattr(self, 'augmentation'):
                x = self.augmentation(x)

            x = x.astype('float32') / 255.0
            y = self.mapping[sample[2]]

            batch_x[i] = x
            batch_y[i] = y

        class_weights = self.class_weights
        weights = np.take(class_weights, batch_y.astype('int64'))

        return batch_x, keras.utils.to_categorical(batch_y, num_classes=self.n_classes), weights

import tensorflow as tf
from tensorflow import keras

import numpy as np
import os
import cv2

导入各个库。

from tensorflow.keras.preprocessing.image import ImageDataGenerator

从tensorflow.keras.preprocessing.image导入ImageDataGenerator函数

以下内容参考:

keras的图像预处理全攻略(二)—— ImageDataGenerator 类_LoveMIss-Y的博客-CSDN博客_imagedatagenerator

【Tool】Keras 基础学习 III ImageDataGenerator() - 简书

ImageDataGenerator类的定义以及构造函数的参数详解

1. ImageDataGenerator类简介

ImageDataGenerator类在xxx\Lib\site-packages\keras_preprocessing\image.py中,定义如下:

class ImageDataGenerator(object):

"""Generate batches of tensor image data with real-time data augmentation.

The data will be looped over (in batches).

"""

这个类是做什么用的?ImageDataGenerator()是keras.preprocessing.image模块中的图片生成器,同时也可以在batch中实时对数据进行增强,扩充数据集大小,增强模型的泛化能力。比如进行旋转,变形,归一化等等,并且可以循环迭代。在Keras中,当数据量很多的时候需要使用model.fit_generator()方法,该方法接受的第一个参数就是一个生成器。简单来说就是:ImageDataGenerator()是keras.preprocessing.image模块中的图片生成器,可以每一次给模型“喂”一个batch_size大小的样本数据,同时也可以在每一个批次中对这batch_size个样本数据进行增强,扩充数据集大小,增强模型的泛化能力。比如进行旋转,变形,归一化等等。

总结起来就是以下两点:

(1)图片生成器,负责生成一个批次一个批次的图片,以生成器的形式给模型训练;

(2)对每一个批次的训练图片,适时地进行数据增强处理(data augmentation)。

2. 数据增强处理(data augmentation)

数据增强有两种方法。一种是事先执行所有转换,实质上会增强你的数据集的大小。另一种选项是在送入机器学习之前,在小批量(mini-batch)上执行这些转换。第一个选项叫做线下增强(offline augmentation)。这种方法适用于较小的数据集(smaller dataset)。你最终会增加一定的倍数的数据集,这个倍数等于你转换的个数。第二种方法叫做线上增强(online augmentation)或在飞行中增强(augmentation on the fly)。这种方法更适用于较大的数据集(larger datasets),因为你无法承受爆炸性增加的规模。

数据增强的手段有非常多种,这里指说一些代表性的。

  • 旋转 | 反射变换(Rotation/reflection): 随机旋转图像一定角度; 改变图像内容的朝向。
  • 翻转变换(flip): 沿着水平或者垂直方向翻转图像。
  • 缩放变换(zoom): 按照一定的比例放大或者缩小图像。
  • 平移变换(shift): 在图像平面上对图像以一定方式进行平移;可以采用随机或人为定义的方式指定平移范围和平移步长, 沿水平或竖直方向进行平移. 改变图像内容的位置。
  • 尺度变换(scale): 对图像按照指定的尺度因子, 进行放大或缩小; 或者参照SIFT特征提取思想, 利用指定的尺度因子对图像滤波构造尺度空间. 改变图像内容的大小或模糊程度。
  • 对比度变换(contrast): 在图像的HSV颜色空间,改变饱和度S和V亮度分量,保持色调H不变. 对每个像素的S和V分量进行指数运算(指数因子在0.25到4之间), 增加光照变化。
  • 噪声扰动(noise): 对图像的每个像素RGB进行随机扰动, 常用的噪声模式是椒盐噪声和高斯噪声。
  • 错切变换(shear):效果就是让所有点的x坐标(或者y坐标)保持不变,而对应的y坐标(或者x坐标)则按比例发生平移,且平移的大小和该点到x轴(或y轴)的垂直距离成正比。

3. ImageDataGenerator类的构造函数参数

keras.preprocessing.image.ImageDataGenerator(featurewise_center=False, samplewise_center=False, featurewise_std_normalization=False, samplewise_std_normalization=False, zca_whitening=False, zca_epsilon=1e-06, rotation_range=0.0, width_shift_range=0.0, height_shift_range=0.0, brightness_range=None, shear_range=0.0, zoom_range=0.0, channel_shift_range=0.0, fill_mode='nearest', cval=0.0, horizontal_flip=False, vertical_flip=False, rescale=None, preprocessing_function=None, data_format=None, validation_split=0.0)

  • featurewise_center: 布尔值。将输入数据的均值设置为 0,逐特征进行,对输入的图片每个通道减去每个通道对应均值。
  • samplewise_center: 布尔值。将每个样本的均值设置为 0,每张图片减去样本均值, 使得每个样本均值为0。
  • featurewise_std_normalization: Boolean. 布尔值。将每个输入(即每张图片)除以数据集(dataset)标准差,逐特征进行。
  • samplewise_std_normalization: 布尔值。将每个输入(即每张图片)除以其自身(图片本身)的标准差。

这里需要注意两个概念,所谓 featurewise指的是逐特征,它针对的是数据集dataset,而samplewise针对的是单个输入图片的本身。featurewise是从整个数据集的分布去考虑的,而samplewise只是针对自身图片 。

  • zca_epsilon: ZCA 白化的 epsilon 值,默认为 1e-6。
  • zca_whitening: 布尔值。是否应用 ZCA 白化。
  • rotation_range: 整数。随机旋转的度数范围。
  • width_shift_range: 它的值可以是浮点数、一维数组、整数。
    • float: 如果 <1,则是除以总宽度的值,或者如果 >=1,则为像素值。
    • 1-D 数组: 数组中的随机元素。
    • int: 来自间隔 (-width_shift_range, +width_shift_range) 之间的整数个像素。
    • width_shift_range=2 时,可能值是整数 [-1, 0, +1],与 width_shift_range=[-1, 0, +1] 相同;而 width_shift_range=1.0 时,可能值是 [-1.0, +1.0) 之间的浮点数。
  • height_shift_range: 浮点数、一维数组或整数(同width_shift_range)。
    • float: 如果 <1,则是除以总宽度的值,或者如果 >=1,则为像素值。
    • 1-D array-like: 数组中的随机元素。
    • int: 来自间隔 (-height_shift_range, +height_shift_range) 之间的整数个像素。
    • height_shift_range=2 时,可能值是整数 [-1, 0, +1],与 height_shift_range=[-1, 0, +1]相同;而 height_shift_range=1.0 时,可能值是 [-1.0, +1.0) 之间的浮点数。
  • shear_range: 浮点数。剪切强度(以弧度逆时针方向剪切角度)。所谓shear_range就是错切变换,效果就是让所有点的x坐标(或者y坐标)保持不变,而对应的y坐标(或者x坐标)则按比例发生平移,且平移的大小和该点到x轴(或y轴)的垂直距离成正比。
  • brightness_range: 两个浮点数组成的元组或者是列表,像素的亮度会在这个范围之类随机确定。
  • zoom_range: 浮点数 或 [lower, upper]。随机缩放范围。如果是浮点数,[lower, upper] = [1-zoom_range, 1+zoom_range]。zoom_range参数可以让图片在长或宽的方向进行放大,可以理解为某方向的resize,因此这个参数可以是一个数或者是一个list。当给出一个数时,图片同时在长宽两个方向进行同等程度的放缩操作;当给出一个list时,则代表[width_zoom_range, height_zoom_range],即分别对长宽进行不同程度的放缩。而参数大于0小于1时,执行的是放大操作,当参数大于1时,执