1. TFRecords文件读取
TFRecords文件其实是一种二进制文件,可以用来存储机器学习中的数据。虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。比如,某些数据集中的特征值和标签往往是分开的,这个TFRecords文件优势就在于标签值与特征值是绑定在一起的。
基本操作流程:
- 获取数据
- 将数据填入Example协议内存块
- 将协议内存块序列化为字符串
2. 构造Example结构生成
想构造Example结构的数据格式,可以使用如下API,如下API需要相互嵌套使用(有些参数需要实例化),使用起来有点像“套娃”,第一个函数内部会嵌套第二个函数,第二个函数内部的参数又需要第三个函数下介绍的参数进行转化:
tf.train.Example(features=None)
- 写入tfrecords文件
- features:tf.train.Features类型的特征实例
- return:example格式协议块
tf.train.Features(feature=None)
- 用来生成Features实例特征,传入到上一个函数中的features参数
- 构建每个样本的信息键值对
- feature:是字典数据,key为要保存的名字
- value:为tf.train.Feature实例
- return:Features类型
tf.train.Feature(options)
- options:根据不同类型去填参数,如下所示
- bytes_list=tf.train.BytesList(value=[Bytes]):适用于字符串类型,ndarray类型也适用
- 我们常用的image图片都是像素组成的n维数组,属于ndarray类型,因此应该序列化成Bytes类型
- int64_list=tf.train.Int64List(value=[label]):适用于整型
- 对于CIFAR10数据集,其label为0-9这样的数字标签,可以归为整型进行序列化
- float_list=tf.FloatList(value=[value]):适用于浮点型
- bytes_list=tf.train.BytesList(value=[Bytes]):适用于字符串类型,ndarray类型也适用
上述API嵌套使用方式/格式:
example = tf.train.Example(features=tf.train.Features(feature={
"image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])
"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label])))
}))
上述代码中的example还是一个对象,若想要写入到本地,就需要我们对其进行实例化。
- 我们还需要使用example.SerializeToString()进行实例化
3. 案例演示:
下面我们将进行tfrecords数据存储案例演示,对CIFAR10数据集进行tfrecords文件存储,将读取到样本特征值与标签值数据写入tfrecords中。下面将从写入部分的函数构造来演示。
-
我们将数据写入到当前路径下的"cifar10.tfrecords"文件中
-
使用上下文管理器对其进行实例化
-
之前设定的batch_size为100,一次性读取100个样本。我们考虑使用for循环遍历100次,并序列化写入文件
-
每一次使用for循环拿一个数据
-
通过print直接打印查看一下输出
- image打印出来的是bytes类型的数据
- label打印出来的是1维数组类型,不直接属于之前所说的三种类型,我们需要将数组内部的值取出来(内部的值是整型,可以用int64存储)
-
读取到数据以后,就可以进行构造example
-
example构造完成以后比忘记还要对其进行序列化
def write_to_tfrecords(self, image_batch, label_batch):
"""
将样本的特征值和目标值一起写入tfrecords文件
"""
with tf.python_io.TFRecordWriter("cifar10.tfrecords") as writer:
# 循环构造example对象,并序列化写入文件
for i in range(100):
image = image_batch[i].tobytes()
label = label_batch[i][0]
print("tf.records_image:\n", image)
print("tfrecords_label:\n", label)
example = tf.train.Example(features=tf.train.Features(feature={
"image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])
"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label])))
}))
# 将序列化后的example写入tfrecords文件
writer.write(example.SerializeToString)
return None
writer_to_tfrecords(image_value, label_value)
在读取数据过程完成后,使用上述代码,可以进行文件存储,在当前目录下生成了一个二进制tfrecords文件。
本文正在参加「金石计划 . 瓜分6万现金大奖」