Keras 数据预处理 ImageDataGenerator

·  阅读 2762
  • 本文主要介绍Keras中图像分类任务用到的图像预处理部分的内容。
  • 注意:并不是介绍Keras中所有的图像预处理函数。

1. 简介

使用Keras进行图像分类任务时,如果数据集较少(数据获取困难等),为了尽可能的充分利用有限数据的价值,可以进行数据增强处理。

通过一系列随机变换对数据进行提升,这样有利于抑制过拟合,提升模型的泛化能力。

Keras中提供了一个用于数据增强的类(Keras.preprocessing.image.ImageDataGenerator)来实现此功能。这个类可以:

  • 在训练过程中,设置要实施的随机变化
  • 通过.flow.flow_from_directory(directory)方法实例化一个针对图像batch的生成器,这些生成器可以被用做keras相关方法的输入,如fit_generator, evaluate_generatorpredict_generator

什么意思呢?——使用ImageDataGenerator类不仅可以在训练过程中进行图像的随机变化,增加训练数据;还附带赠送了获取数据batch生成器对象的功能,省去了手工再去获取batch数据的部分。

2. ImageDataGenerator类介绍

ImageDataGenerator类路径:keras/preprocessing/image.py

作用:通过实时数据增强生成批量图像数据向量。训练时该函数会无限循环生成数据,直到达到规定的epoch次数为止。

ImageDataGenerator继承于keras_preprocessing/image/image_data_generator.py中的ImageDataGenerator类。

# keras/preprocessing/image.py
class ImageDataGenerator(image.ImageDataGenerator):
    def __init__(self,
                 featurewise_center=False,
                 samplewise_center=False,
                 featurewise_std_normalization=False,
                 samplewise_std_normalization=False,
                 zca_whitening=False,
                 zca_epsilon=1e-6,
                 rotation_range=0,
                 width_shift_range=0.,
                 height_shift_range=0.,
                 brightness_range=None,
                 shear_range=0.,
                 zoom_range=0.,
                 channel_shift_range=0.,
                 fill_mode='nearest',
                 cval=0.,
                 horizontal_flip=False,
                 vertical_flip=False,
                 rescale=None,
                 preprocessing_function=None,
                 data_format=None,
                 validation_split=0.0,
                 dtype=None):
复制代码

参数

  • featurewise_center:布尔值,使输入数据集去中心化(均值为0),逐特征进行。
  • samplewise_center:布尔值,使输入数据的每个样本均值为0
  • featurewise_std_normalization:布尔值,将输入除以数据集的标准差以完成标准化, 按feature执行
  • samplewise_std_normalization:布尔值,将输入的每个样本除以其自身的标准差
  • zca_whitening:布尔值,对输入数据施加ZCA白化
  • zca_epsilon: ZCA使用的eposilon,默认1e-6
  • rotation_range:整数,图片随机转动的角度范围
  • width_shift_range:浮点数,一维数组或整数,图片宽度的某个比例,数据提升时图片水平偏移的幅度
    • float:如果<1,则除以总宽度的值,如果>=1,则为宽度像素值
    • 一维数组:数组中的随机元素
    • 整型:来自间隔(-width_shift_range,width_shift_range)之间的整数个像素
    • width_shift_range=2:可能值是整数[-1,0,1],与width_shift_range=[-1,0,1]相同,而当width_shfit_range=1.0时,可能值是半开区间[-1.0,1.0]之间的浮点数(后半句没有理解)。
  • height_shift_range:浮点数,图片高度的某个比例,数据提升时图片竖直偏移的幅度。具体含义与width_shift_range相同。
  • brightness_range:两个float组成的元组或列表。选择亮度值的范围
  • shear_range:浮点数,剪切强度(逆时针方向的剪切变换角度)
  • zoom_range:浮点数或[lower, upper]。随机缩放范围,如果是浮点数,[lower, upper] = [1-zoom_range, 1+zoom_range]
  • channel_shift_range:浮点数,随机通道转换的范围。
  • fill_mode{"constant", "nearest", "reflect" or "wrap"} 之一。默认为'nearest'。输入边界以外的点根据给定的模式填充:
    • 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
    • 'nearest': aaaaaaaa|abcd|dddddddd
    • 'reflect': abcddcba|abcd|dcbaabcd
    • 'wrap': abcdabcd|abcd|abcdabcd
  • cval: 浮点数或整数。用于边界之外的点的值,当fill_mode = "constant"时。
  • horizontal_flip: 布尔值,随机水平翻转。
  • vertical_flip: 布尔值,随机垂直翻转。
  • rescale: 重缩放因子。默认为 None。如果是 None 或 0,不进行缩放,否则将数据乘以所提供的值(在应用任何其他转换之后
  • preprocessing_function:该函数应用于每个输入上,在图像被resize和增强之后运行。该函数接收一个参数,一张图像(秩为3的numpy tensor),同样输出一个相同shapeNumpy tensor
  • data_format:图像数据格式,{"channels_first", "channels_last"} 之一。"channels_last" 模式表示图像输入尺寸应该为(samples, height, width, channels)"channels_first" 模式表示输入尺寸应该为(samples, channels, height, width)。默认为 在 Keras 配置文件~/.keras/keras.json中的image_data_format值。如果你从未设置它,那它就是"channels_last"
  • validation_split:浮点型。保留用于验证集的图像比例(严格在0,1之间)
  • dtype:生成数组使用的数据类型。

使用示例

from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

data_generator = datagen.flow_from_directory('./datas/train', target_size=(224,224), batch_size=32)
复制代码

3. ImageDataGenerator类方法

该类的几个重要方法如下:

  • flow(): 该方法**输入数据(Numpy或元组形式)**和标签(可选),返回一个迭代器,格式是元组(x,y)(x)(x,y,sample_weight)。该方法还可以指定样本输出路径及前缀,格式,用于保存增强处理后的图像。
  • flow_from_directory(): 获取图像路径,生成批量增强数据。该方法只需指定数据所在的路径,而无需输入numpy形式的数据,也无需输入标签值,会自动返回对应的标签值。返回一个生成(x, y)元组的DirectoryIterator
  • flow_from_dataframe(): 输入数据为Pandas dataframe格式。返回生成(x, y) 元组的DataFrameIterator

注意事项

  • 主要区别是输入数据和输出数据的格式不同。
  • flow_from_directory()flow_from_dataframe()两个函数都将图像resize到指定大小。而flow()无此步骤。

3.1 fit()

该方法使数据生成器适合于某些样本数据,它根据样本数据数组计算与数据依赖转换相关的内部数据统计信息。

只有在featurewise_centerfeaturewise_std_normalizationzca_whitening为设置为True时才需要计算。

即实现对数据的去中心化/标准化/ZCA白化处理。 使用的数据均值、标准差都是数据自身的。

函数定义def fit(self, x, augment=False,rounds=1, seed=None)
参数:

  • x: 样本数据,秩为4,对于灰度图像,通道axis应该为1,如何是RGB数据 ,应该为3,如果是RGBA数据,应该为4。
  • augment:布尔型,默认False,是否应用随机增强
  • rounds:整型,默认1。如果augment=True,这是传递给数据使用的扩充量。
  • seed:整型,默认None,随机种子。

返回值
一个生成元组(x, y)Iterator,其中x是图像数据的Numpy数组(在单张图像输入时),或 Numpy 数组列表(在额外多个输入时),y 是对应的标签的Numpy数组。如果 'sample_weight'不是 None,生成的元组形式为(x, y, sample_weight)。如果 y 是 None, 只有Numpy数组x被返回。

具体实现

def fit(self, x,
    augment=False,
    rounds=1,
    seed=None):
    
    # 此处为合规性检测
    # 数据去中心化
    if self.featurewise_center:
        self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis))
        broadcast_shape = [1, 1, 1]
        broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
        self.mean = np.reshape(self.mean, broadcast_shape)
        x -= self.mean
    # 数据标准化
    if self.featurewise_std_normalization:
        self.std = np.std(x, axis=(0, self.row_axis, self.col_axis))
        broadcast_shape = [1, 1, 1]
        broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
        self.std = np.reshape(self.std, broadcast_shape)
        x /= (self.std + 1e-6)
    # 数据ZAC白化处理
    if self.zca_whitening:
        if scipy is None:
            raise ImportError('Using zca_whitening requires SciPy. '
                              'Install SciPy.')
        flat_x = np.reshape(
            x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
        sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0]
        u, s, _ = linalg.svd(sigma)
        s_inv = 1. / np.sqrt(s[np.newaxis] + self.zca_epsilon)
        self.principal_components = (u * s_inv).dot(u.T)
复制代码

3.2 flow()

采集数据和标签数组,生成批量增强数据。

函数定义

def flow(self,
         x,
         y=None,
         batch_size=32,
         shuffle=True,
         sample_weight=None,
         seed=None,
         save_to_dir=None,
         save_prefix='',
         save_format='png',
         subset=None)
复制代码

参数

  • x: 输入数据。秩为 4 的 Numpy 矩阵或元组。如果是元组,第一个元素应该包含图像,第二个元素是另一个 Numpy 数组或一列 Numpy 数组,它们不经过任何修改就传递给输出。可用于将模型杂项数据与图像一起输入。对于灰度数据,图像数组的通道轴的值应该为 1,而对于 RGB 数据,其值应该为 3。
  • y:标签
  • batch_size:整型,默认32
  • shuffle:布尔型,默认True,是否混洗数据
  • sample_weight:样本权重
  • seed:默认None
  • save_to_dir:None 或 字符串(默认为 None)。这使您可以选择指定要保存的正在生成的增强图片的目录
  • save_prefix: 字符串(默认 '')。保存图片的文件名前缀(仅当 save_to_dir 设置时可用)。
  • save_format: "png", "jpeg" 之一(仅当 save_to_dir 设置时可用)。默认:"png"。
  • subset: 数据子集 ("training""validation"),如果 在 ImageDataGenerator 中设置了 validation_split

返回值:
一个生成元组 (x, y)Iterator,其中 x 是图像数据的Numpy 数组(在单张图像输入时),或 Numpy 数组列表(在额外多个输入时),y 是对应的标签的 Numpy 数组。如果 'sample_weight' 不是 None,生成的元组形式为(x, y, sample_weight)。如果 y 是 None, 只有 Numpy 数组 x 被返回。

内部调用数组迭代器类:

    return NumpyArrayIterator(
        x,
        y,
        self,
        batch_size=batch_size,
        shuffle=shuffle,
        sample_weight=sample_weight,
        seed=seed,
        data_format=self.data_format,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        subset=subset
    )
复制代码

3.3 flow_from_directory()

功能: 获取图像路径,生成批量增强数据。

函数定义

def flow_from_directory(self,
                directory,
                target_size=(256, 256),
                color_mode='rgb',
                classes=None,
                class_mode='categorical',
                batch_size=32,
                shuffle=True,
                seed=None,
                save_to_dir=None,
                save_prefix='',
                save_format='png',
                follow_links=False,
                subset=None,
                interpolation='nearest')
复制代码

参数

  • directory:目标目录的路径。每个类应该包含一个子目录。任何在子目录树下的 PNG, JPG, BMP, PPMTIF 图像,都将被包含在生成器中。
  • target_size:整数元组(height,width),默认:(256,256)。所有的图像将被调整到的尺寸.
  • color_mode:"grayscale","rbg"之一。默认:"rgb"。图像是否被转换成 1 或 3 个颜色通道。
  • classes:可选的类的子目录列表(例如['dogs', 'cats'])。默认:None。如果未提供,类的列表将自动从 directory下的子目录名称/结构中推断出来,其中每个子目录都将被作为不同的类(类名将按字典序映射到标签的索引)。包含从类名到类索引的映射的字典可以通过class_indices属性获得。
  • class_model"categorical", "binary", "sparse", "input"None之一。默认:"categorical"。决定返回的标签数组的类型:
    • "categorical"2D one-hot编码标签,
    • "binary"将是1D二进制标签,"sparse"将是 1D 整数标签,
    • "input"将是与输入图像相同的图像(主要用于自动编码器)。
    • 如果为 None,不返回标签(生成器将只产生批量的图像数据,对于 model.predict_generator(), model.evaluate_generator() 等很有用)。请注意,如果 class_modeNone,那么数据仍然需要驻留在 directory 的子目录中才能正常工作。
  • batch_size: 一批数据的大小(默认 32)。
  • shuffle:是否混洗数据(默认 True)
  • seed:可选随机种子,用于混洗和转换。
  • save_to_dir:None或字符串(默认None)。这使你可以最佳地指定正在生成的增强图片要保存的目录(用于可视化你在做什么)。
  • save_format:字符串。 保存图片的文件名前缀(仅当 save_to_dir 设置时可用)。
  • follow_links:是否跟踪类子目录中的符号链接(默认为 False)。
  • subset:数据子集("training""validation"),如果 在 ImageDataGenerator 中设置了 validation_split
  • interpolation:在目标大小与加载图像的大小不同时,用于重新采样图像的插值方法。 支持的方法有 "nearest", "bilinear", and "bicubic"。 如果安装了 1.1.3 以上版本的 PIL 的话,同样支持 "lanczos"。 如果安装了 3.4.0 以上版本的 PIL 的话,同样支持 "box""hamming"。 默认情况下,使用 "nearest"

返回值
一个生成(x, y)元组的 DirectoryIterator,其中 x 是一个包含一批尺寸为 (batch_size, *target_size, channels)的图像的 Numpy 数组,y 是对应标签的 Numpy 数组。

3.4 flow_from_dataframe()

功能: 输入dataframe和目录的路径,并生成批量的增强/标准化数据。

该函数的输入数据格式为Pandas dataframe

函数定义

def flow_from_dataframe(self, dataframe, directory=None,
                x_col="filename", y_col="class", weight_col=None,
                target_size=(256, 256), color_mode='rgb', classes=None,
                class_mode='categorical', batch_size=32, shuffle=True, seed=None,
                save_to_dir=None, save_prefix='', save_format='png', subset=None,
                interpolation='nearest', validate_filenames=True, **kwargs)
复制代码

参数:

  • dataframe: Pandas dataframe,一列为图像的文件名,另一列为图像的类别, 或者是可以作为原始目标数据多个列。
  • directory: 字符串,目标目录的路径,其中包含在 dataframe 中映射的所有图像。
  • x_col : 字符串,dataframe中包含目标图像文件夹的目录的列。
  • y_col: 字符串或字符串列表,dataframe中将作为目标数据的列。
  • has_ext: 布尔值,如果 dataframe[x_col] 中的文件名具有扩展名则为True,否则为 False
  • target_size: 整数元组(height, width),默认为 (256, 256)。所有找到的图都会调整到这个维度。
  • color_mode: "grayscale", "rbg" 之一。默认:"rgb"。 图像是否转换为 1 个或 3 个颜色通道。
  • classes: 可选的类别列表 (例如, ['dogs', 'cats'])。默认:None。 如未提供,类比列表将自动从 y_col 中推理出来,y_col将会被映射为类别索引)。 包含从类名到类索引的映射的字典可以通过属性class_indices获得。
  • class_mode: "categorical", "binary", "sparse", "input", "other" or None之一。 默认:"categorical"。决定返回标签数组的类型:
    • "categorical" 将是2D one-hot编码标签,
    • "binary" 将是 1D 二进制标签,
    • "sparse" 将是 1D 整数标签,
    • "input" 将是与输入图像相同的图像(主要用于与自动编码器一起使用),
    • "other" 将是y_col数据的numpy数组,None,不返回任何标签(生成器只会产生批量的图像数据,这对使用 model.predict_generator(), model.evaluate_generator() 等很有用)。
  • batch_size: 批量数据的尺寸(默认:32)。
  • shuffle: 是否混洗数据(默认:True)
  • seed: 可选的混洗和转换的随即种子。
  • save_to_dir: Nonestr (默认: None). 这允许你可选地指定要保存正在生成的增强图片的目录(用于可视化您正在执行的操作)。
  • save_prefix: 字符串。保存图片的文件名前缀(仅当 save_to_dir 设置时可用)。
  • save_format: "png","jpeg"之一(仅当save_to_dir设置时可用)。默认:"png"
  • follow_links: 是否跟随类子目录中的符号链接(默认:False)。
  • subset: 数据子集 ("training""validation"),如果在ImageDataGenerator 中设置了validation_split
  • interpolation: 在目标大小与加载图像的大小不同时,用于重新采样图像的插值方法。 支持的方法有"nearest", "bilinear", and "bicubic"。 如果安装了 1.1.3 以上版本的 PIL 的话,同样支持"lanczos"。 如果安装了 3.4.0 以上版本的 PIL 的话,同样支持 "box""hamming"。 默认情况下,使用"nearest"

返回值
一个生成(x, y) 元组的DataFrameIterator, 其中x是一个包含一批尺寸为 (batch_size, *target_size, channels)的图像样本的numpy数组,y 是对应的标签的 numpy 数组。

3.5 standardize()

此函数主要是对一组batch输入数据进行标准化处理。

主要步骤:

  • 如果preprocessing_function不为空,则执行该指定函数的处理x = self.preprocessing_function(x)
  • 如果rescaleTrue,则执行x*=self.rescale
  • 如果samplewise_centerTrue,则执行x-=np.mean(x, keepdims=True)去中心化

    计算的是当前一组batch数据的均值

  • 如果samplewise_std_normalizationTrue,则执行x /= (np.std(x, keepdims=True) + 1e-6)标准化

    计算的是当前一组batch数据的标准差

  • 如果featurewise_centerTrue,self.mean不为空,则执行x -= self.mean去中心化,否则给出警告
  • 如果featurewise_std_normalizationTrue,self.std不为空,则执行x /= (self.std + 1e-6)去中心化,否则给出警告
  • 如果zca_whiteningTrueself.principal_components不为空,则执行计算,否则给出警告

该函数的调用在_get_batches_of_transformed_samples()函数内,用来获取一组batch处理后的输入数据。

filepaths = self.filepaths
for i, j in enumerate(index_array):
    img = load_img(filepaths[j],
                   color_mode=self.color_mode,
                   target_size=self.target_size,
                   interpolation=self.interpolation)
    x = img_to_array(img, data_format=self.data_format)
    # Pillow images should be closed after `load_img`,
    # but not PIL images.
    if hasattr(img, 'close'):
        img.close()
    if self.image_data_generator:
        params = self.image_data_generator.get_random_transform(x.shape)
        x = self.image_data_generator.apply_transform(x, params)
        x = self.image_data_generator.standardize(x) # 执行标准化处理
    batch_x[i] = x
复制代码

越写越觉得有问题,如果featurewise_center设置为True,其调用了.fit()函数对数据进行了处理(此时self.mean已经被赋值),那么再执行到self.image_data_generator.standardize(x)的时候,岂不是又进行了一次去中心化的处理??

4. 具体使用

在使用Keras进行图像分类任务时,可以将训练数据按照以下结构进行保存:

datas/
    train/
        dogs/
            dog01.jpg
            dog02.jpg
            ...
        cats/
            cat01.jpg
            cat02.jpg
            ...
    validation/
        dogs/
            dog01.jpg
            dog02.jpg
            ...
        cats/
            cat01.jpg
            cat02.jpg
            ...
复制代码
  • 每个分类的图像存入一个文件夹中,按照训练集、验证集分开存放。
  • 调用flow_from_directory()函数时,数据标签值可自动根据数据子目录的名称/结构推断出来,每个子目录都被作为不同的类。因此,标签值可以不输入。

4.1 示例1

该示例不进行数据的标准化/去中心化/ZAC白化处理

如果需要进行数据增强,则按以下步骤:

# 调用ImageDataGenerator类,
train_datagen = ImageDataGenerator(
        rotation_range=30,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

# 通过flow_from_directory方法实例化一个生成器,它将不断循环生成batch数据
train_generator = train_datagen.flow_from_directory(
    './datas/train',
    target_size=(config.image_size, config.image_size),
    class_mode='categorical',
    batch_size=config.batch_size)
    
# 将生成器作为模型训练函数fit_generator的参数
model.fit_generator(train_generator,
        steps_per_epoch=nb_train_samples//config.batch_size + 1,
        epochs=config.epochs,
        validation_data=val_generator,
        validation_steps=nb_val_samples//config.batch_size + 1,
        callbacks=callbacks)
复制代码

从训练函数fit_generator的名称中也可以得出,它的输入是生成器对象。

同样,还可以应用于evaluate_generatorpredict_generator方法。

4.2 对数据进行去中心化/标准化

4.2.1 通过调用ImageDateGenerator.fit()函数实现

由前面知识可知,fit()函数的作用是将数据生成器用于示例数据。
当参数featurewise_centerfeaturewise_std_normalizationzca_whitening为设置为True时会对输入的数据x相应执行去中心化、规范化、ZCA白化处理。

以去中心化为例,具体处理如下:

  • 计算输入数据x的均值
  • 执行x-=self.mean
# ImageDataGenerator.fit()
if self.featurewise_center:
    self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis))
    broadcast_shape = [1, 1, 1]
    broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
    self.mean = np.reshape(self.mean, broadcast_shape)
    x -= self.mean
复制代码

需要注意:均值是计算输入数据的均值。 需设置对应参数为True

如果不想使用输入数据自己的均值、标准差,则不能调用该函数,使用其他方法。

4.2.2 通过ImageDateGenerator.standardize()函数实现

该函数的作用是:将适当的规范化配置应用于一批输入数据

输入:输入将要被规范化处理的一组batch数据
返回:经过规范化处理后的输入数据

该函数对输入x是就地更改的,因为主要在用在内部对图像进行标准化并提供给网络,如果创建副本,会产生巨大的性能成本。

该函数对数据的处理包含两种情况:

  1. 参数featurewise_centerfeaturewise_std_normalization设为True,但是数据没有经过.fit()函数进行处理。
    此时,没有通过fit()函数求取数据的均值和标准差等数据,则数据的均值和标准差是为空的,需要指定。
    默认为None

    self.mean = None
    self.std = None
    复制代码

    处理如下:

    if self.featurewise_center:
        if self.mean is not None:
            x -= self.mean
        else:
            warnings.warn('This ImageDataGenerator specifies '
                          '`featurewise_center`, but it hasn\'t '
                          'been fit on any training data. Fit it '
                          'first by calling `.fit(numpy_data)`.')
    if self.featurewise_std_normalization:
        if self.std is not None:
            x /= (self.std + 1e-6)
        else:
            warnings.warn('This ImageDataGenerator specifies '
                          '`featurewise_std_normalization`, '
                          'but it hasn\'t '
                          'been fit on any training data. Fit it '
                          'first by calling `.fit(numpy_data)`.')
    复制代码

    需要手动设置数据的均值和标准差,才可以进行数据的处理,否则给出警告信息,正确使用方式如下:

    datagen = ImageDataGenerator(
                featurewise_center=True,
                rotation_range=30,
                shear_range=0.2,
                zoom_range=0.2)
    # 手动设置数据均值
    datagen.mean = np.array(config.data_mean, dtype=np.float32).reshape((1,1,3))
    train_generator = datagen.flow_from_directory(config.train_data,
                              target_size=img_size,
                              batch_size=batch_size,
                              class_mode=None,
                              shuffle=False)
    复制代码
  2. 如果数据没有经过.fit()函数进行处理,且不知道数据均值/标准差,则通过samplewise_centersamplewise_std_normalization参数也可以处理。

    该参数自动计算当前传入一组batch数据的均值及标准差

    # ImageDateGenerator()
    def standardize(self, x):
        if self.samplewise_center:
            x -= np.mean(x, keepdims=True)
        if self.samplewise_std_normalization:
            x /= (np.std(x, keepdims=True) + 1e-6)
    复制代码
  3. 通过preprocessing_function参数指定处理函数

    如果想把对数据处理的过程封装为一个单独的函数,则可以不使用上面介绍的方法(通过指定参数samplewise_centersamplewise_std_normalizationfeaturewise_centerfeaturewise_std_normalization

    如果self.preprocessing_function设置为处理函数,则先执行预处理函数。

    if self.preprocessing_function:
        x = self.preprocessing_function(x)
    复制代码

    调用方法

    # 指定处理 函数preprocessing_function
    train_datagen = ImageDataGenerator(
        preprocessing_function=preprocess_input,
        rotation_range=30,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
    train_generator = train_datagen.flow_from_directory(
        './datas/train',
        target_size=(224,224),
        class_mode='categorical',
        batch_size=32)
    model.fit_generator(train_generator,
                steps_per_epoch=nb_train_samples//config.batch_size + 1,
                epochs=config.epochs,
                callbacks=callbacks)
    复制代码
分类:
人工智能
标签:
分类:
人工智能
标签: