深度学习1 基于h5py使用数据迭代器训练超过内存的数据

1,525 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路

背景

在使用keras训练图像数据时,通常使用ImageDataGenerator()的方式迭代目录,分批次读取。而对于容量超大甚至超过内存数组文件(npy文件或者mat文件)时无法通过文件迭代器来分批次读取,因此提出该方法。

解决方法

1、安装h5py库 使用该命令:pip install h5py

2、将数组文件载入内容重新保存为h5py文件。分批次将数组存为h5py的方法可以参考blog.csdn.net/a486259/art…

3、构建迭代器,使用h5py['target'][index:index+batch_size]的方法不断迭代数据

4、使用迭代的方法训练模型

原理说明

当使用h5py['target'].shape、h5py['target'].[index:index+batch_size]等命令时,并不会将h5py['target']所对应的全部数据加载到内存中。因此,可以使用h5py[index:index+batch_size]获取数据,h5py只会将相应数据index位置起的batch_size条记录返回。虽然数据是逐次加载,但是显著的降低了对内存的需求。

操作案例

1、将目标数据转存为h5py文件 (代码中的save_h5方法来自于blog.csdn.net/a486259/art…

我这里数据用data表示,标签用i_labe表示

2、构造数据迭代器

import h5py
from sklearn.model_selection import train_test_split
from keras.utils.np_utils import to_categorical
def data_gen(dtype='train',rate=0.2,batch_size=500):
    file_name='all_data.h5'  #刚刚存储的文件
    h5f=h5py.File(file_name)
    while True:
        for index in range(0,h5f['data'].shape[0],batch_size):
            data=h5f['data'][index:index+batch_size]  #读取数据
            i_label=h5f['i_label'][index:index+batch_size]  #读取标签
            i_label = to_categorical(i_label, num_classes=hz_nums)  #将标签转化为独热码的形式
            X_train, X_test, y_train, y_test =train_test_split(data,i_label,train_size=rate, test_size=1-rate, random_state=42) #将数据划分为训练集和测试集
            if dtype=='train':   #按照不同的需求返回数据
                yield (X_train, y_train)
            else:
                yield (X_test, y_test)

3、使用迭代器训练数据

其中,steps_per_epoch表示需要迭代多少次才能获取一个epoch的训练数据,validation_steps表示需要迭代多少次才能获取一个epoch的验证数据

因为我的数据是在迭代时才划分的,所以steps_per_epoch跟validation_steps相等

file_name='all_data.h5'
h5f=h5py.File(file_name)
all_data_num=h5f['data'].shape[0] #获取全部数据的量
rate=0.8   #训练集划分的比例
batch_size=1000    #每次取数据的batch_size
model.fit_generator(data_gen(dtype='train',rate=rate,batch_size=batch_size),
                    steps_per_epoch=all_data_num// batch_size,
                    epochs=100,
                    validation_data=(data_gen(dtype='test',rate=rate,batch_size=batch_size)),
                    validation_steps=all_data_num// batch_size,
                    callbacks=callbacks,
                    shuffle=True,
                    verbose=1)

4、使用迭代器的数据评估模型

loss,acc=model.evaluate_generator(data_gen(dtype='test',rate=rate,batch_size=batch_size),steps=all_da