【深度学习】TensorFlow:tfrecords文件存储

368 阅读3分钟

1. TFRecords文件读取

TFRecords文件其实是一种二进制文件,可以用来存储机器学习中的数据。虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。比如,某些数据集中的特征值和标签往往是分开的,这个TFRecords文件优势就在于标签值与特征值是绑定在一起的。

基本操作流程:

  1. 获取数据
  2. 将数据填入Example协议内存块
  3. 将协议内存块序列化为字符串

2. 构造Example结构生成

想构造Example结构的数据格式,可以使用如下API,如下API需要相互嵌套使用(有些参数需要实例化),使用起来有点像“套娃”,第一个函数内部会嵌套第二个函数,第二个函数内部的参数又需要第三个函数下介绍的参数进行转化:

tf.train.Example(features=None)

  • 写入tfrecords文件
  • features:tf.train.Features类型的特征实例
  • return:example格式协议块

tf.train.Features(feature=None)

  • 用来生成Features实例特征,传入到上一个函数中的features参数
  • 构建每个样本的信息键值对
  • feature:是字典数据,key为要保存的名字
  • value:为tf.train.Feature实例
  • return:Features类型

tf.train.Feature(options)

  • options:根据不同类型去填参数,如下所示
    • bytes_list=tf.train.BytesList(value=[Bytes]):适用于字符串类型,ndarray类型也适用
      • 我们常用的image图片都是像素组成的n维数组,属于ndarray类型,因此应该序列化成Bytes类型
    • int64_list=tf.train.Int64List(value=[label]):适用于整型
      • 对于CIFAR10数据集,其label为0-9这样的数字标签,可以归为整型进行序列化
    • float_list=tf.FloatList(value=[value]):适用于浮点型

上述API嵌套使用方式/格式:

example = tf.train.Example(features=tf.train.Features(feature={
"image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])
"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label])))
}))

上述代码中的example还是一个对象,若想要写入到本地,就需要我们对其进行实例化。

  • 我们还需要使用example.SerializeToString()进行实例化

3. 案例演示:

下面我们将进行tfrecords数据存储案例演示,对CIFAR10数据集进行tfrecords文件存储,将读取到样本特征值与标签值数据写入tfrecords中。下面将从写入部分的函数构造来演示。

  • 我们将数据写入到当前路径下的"cifar10.tfrecords"文件中

  • 使用上下文管理器对其进行实例化

  • 之前设定的batch_size为100,一次性读取100个样本。我们考虑使用for循环遍历100次,并序列化写入文件

  • 每一次使用for循环拿一个数据

  • 通过print直接打印查看一下输出

    • image打印出来的是bytes类型的数据
    • label打印出来的是1维数组类型,不直接属于之前所说的三种类型,我们需要将数组内部的值取出来(内部的值是整型,可以用int64存储)
  • 读取到数据以后,就可以进行构造example

  • example构造完成以后比忘记还要对其进行序列化

def write_to_tfrecords(self, image_batch, label_batch):
    """
    将样本的特征值和目标值一起写入tfrecords文件
    """
    with tf.python_io.TFRecordWriter("cifar10.tfrecords") as writer:
        # 循环构造example对象,并序列化写入文件
        for i in range(100):
            image = image_batch[i].tobytes()
            label = label_batch[i][0]
            print("tf.records_image:\n", image)
            print("tfrecords_label:\n", label)
            example = tf.train.Example(features=tf.train.Features(feature={
                "image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])
                "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label])))
            }))

            # 将序列化后的example写入tfrecords文件
            writer.write(example.SerializeToString)

            
    return None

writer_to_tfrecords(image_value, label_value)

在读取数据过程完成后,使用上述代码,可以进行文件存储,在当前目录下生成了一个二进制tfrecords文件。

本文正在参加「金石计划 . 瓜分6万现金大奖」