TFRecords数据

918 阅读4分钟
TFRecords是tensorflow官网提供的一种二进制文件,它能方便的进行数据复制 移动 和更好的利用内存,同时不需要单独的标签文件(在读取数据文件是自动添加标签,下面有介绍);在训练时,使用TFRecords中数据的流程:首先生成xxx.tfrecord文件,接着使用input pipeline读取xxx.tfrecords文件/其他支持格式,then读取并解码数据,随机乱序(shuffle),生成文件序列(batch);最后输入到模型中。
如果有一串jpg图片地址和相应的标签:ImagesLabels
1. 生成TFRecords
存入TFRecords文件需要数据先存入名为example的protocol buffer,然后将其serialize成为string才能写入。example中包含features,用于描述数据类型:bytes,float,int64;具体来说,TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的定义。


  • message Example {



  • Features features = 1;



  • };



  • message Features {



  • map<string, Feature> feature = 1;



  • };



  • message Feature {



  • oneof kind {



  • BytesList bytes_list = 1;



  • FloatList float_list = 2;



  • Int64List int64_list = 3;



  • }



  • };





  • # -*- coding: utf-8 -*-



  • import os



  • import tensorflow as tf



  • from PIL import Image



  • import matplotlib.pyplot as plt



  • import numpy as np







  • cwd = "E:/Anaconda3/tensorflow/Dataset/data/"



  • classes = {'cats', 'dogs'} #预先自己定义的类别



  • #将数据转化TFRecord文件对应的属性



  • def _int64_feature(value):



  • return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))







  • def _bytes_feature(value):



  • return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))



  • # 开始将数据写入TFRecord文件(xxx.tfrecord)



  • train_filename = 'tensorflow/train.tfrecords' # 输出文件地址



  • # 创建一个writer来写TFRecords文件(写TFRecords <==> 输出TFRecords文件)



  • writer = tf.python_io.TFRecordWriter(train_filename) #输出成tfrecord文件











  • for index, name in enumerate(classes): # 从classes中自动获取类别 (label)



  • class_path = cwd + name + '//'



  • for img_name in os.listdir(class_path):



  • img_path = class_path + img_name #每张picture的绝对地址







  • img = Image.open(img_path)



  • img = img.resize((640, 320))



  • img_raw = img.tobytes() #将图片转化为二进制格式



  • # 创建一个属性(feature)



  • example = tf.train.Example(features = tf.train.Features(feature = {



  • "label":_int64_feature(index),



  • "img_raw":_bytes_feature(img_raw),



  • }))



  • # 将上面的example protocol buffer 写入文件



  • writer.write(example.SerializeToString()) #序列化为字符串



  • writer.close()



输入: 数据文件路径 path
输出: xxx.tfrecords文件
reference:blog.csdn.net/hjxu2016/ar…
2. 读取TFRecord 文件
(1). 用tf.train.string_input_producer 读取tfrecords文件(xxx.tfrecords)的list建立文件名队列(FIFO序列),同时,可以申明num_epoches和shuffle参数表示需要读取数据的次数以及时候将tfrecords文件读入顺序打乱;结果:图像路径list
(2). 定义TFRecordReader读取(1)中的序列(图像路径list)返回下一个record;结果:serialize example和feature字典
(3). 用tf.parse_string_example对读取的TFRecords文件进行解码,抽取((2) serialize example和feature字典)中,返回feature对应的值,此时对应的值都是string,需要经过tf.decode(...) 和 tf.cast(...)等操作,将string类型的图像数据还原原始图像;同时也可以进行一些preprocessing操作;
(4). 利用tf.train.shuffle_batch(...)和tf.train.batch(...)将(3)中还原原始图像生成batch图像序列


  • #读取文件



  • def read_and_decode(filename,batch_size):



  • #根据文件名生成一个队列



  • filename_queue = tf.train.string_input_producer([filename])



  • reader = tf.TFRecordReader()



  • _, serialized_example = reader.read(filename_queue) #返回文件名和文件



  • features = tf.parse_single_example(serialized_example,



  • features={



  • 'label': tf.FixedLenFeature([], tf.int64),



  • 'img_raw' : tf.FixedLenFeature([], tf.string),



  • })







  • img = tf.decode_raw(features['img_raw'], tf.uint8)



  • img = tf.reshape(img, [300, 300, 3]) #图像归一化大小



  • # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #图像减去均值处理,根据自己的需要决定要不要加上



  • label = tf.cast(features['label'], tf.int32)







  • #特殊处理,去数据的batch,如果不要对数据做batch处理,也可以把下面这部分不放在函数里







  • img_batch, label_batch = tf.train.shuffle_batch([img, label],



  • batch_size= batch_size,



  • num_threads=64,



  • capacity=200,



  • min_after_dequeue=150)



  • return img_batch, tf.reshape(label_batch,[batch_size])



在读取到队列中后,数据输出之前还要作解码的操作从,可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量;
输入:XXX.tfrecords batch_size
输出: image_batch label_batch
3. 扩展
由于tf.train()函数在graph中增加了tf.train.QueueRunner类(在线程中运行线程中的队列数据),tf.train.start_queue_runner启动所有graph中的线程;用tf.train.Coordinator来管理线程(启动多少线程 何时终止线程...)


  • # initialize global & local variables



  • init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())



  • sess.run(init_op)



  • # create a coordinate and run queue runner objects



  • # 启动多线程处理数据



  • coord = tf.train.Coordinator()



  • threads = tf.train.start_queue_runners(coord=coord)



  • for batch_index in range(3):



  • batch_images, batch_labels = sess.run([images, labels])



  • for i in range(10):



  • plt.imshow(batch_images[i, ...])



  • plt.show()



  • print "Current image label is: ", batch_lables



  • # close threads 结束线程



  • coord.request_stop()



  • coord.join(threads)



  • sess.close()



4. 如何显示xxx.tfrecords文件中的图片


  • tfrecords_file = 'E:/Anaconda3/tensorflow//dataset/train.tfrecords'



  • Batch_size = 6



  • image_batch, label_batch = read_and_decode(tfrecords_file,Batch_size)







  • with tf.Session() as sess:







  • i = 0



  • coord = tf.train.Coordinator()



  • threads = tf.train.start_queue_runners(coord=coord)







  • try:



  • while not coord.should_stop() and i<1:



  • # just plot one batch size



  • image, label = sess.run([image_batch, label_batch])



  • for j in np.arange(4):



  • print('label: %d' % label[j])



  • plt.imshow(image[j,:,:,:])



  • plt.show()



  • i+=1



  • except tf.errors.OutOfRangeError:



  • print('done!')



  • finally:



  • coord.request_stop()



  • coord.join(threads)



batch_size这里可以大家任意设定,显示几幅图片都可以,这里设置为6 同时i 控制显示张数
5. 完整代码


  • # -*- coding: utf-8 -*-



  • import os



  • import tensorflow as tf



  • from PIL import Image



  • import matplotlib.pyplot as plt



  • import numpy as np







  • cwd = "E:/Anaconda3/tensorflow/dataset/data/"



  • classes = {'cats', 'dogs'}



  • writer = tf.python_io.TFRecordWriter('train.tfrecords')







  • def _int64_feature(value):



  • return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))







  • def _bytes_feature(value):



  • return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))







  • for index, name in enumerate(classes):



  • class_path = cwd + name + '//'



  • for img_name in os.listdir(class_path):



  • img_path = class_path + img_name #每张图片的绝对地址







  • img = Image.open(img_path)



  • img = img.resize((640, 320))



  • img_raw = img.tobytes() #将图片转化为二进制格式



  • example = tf.train.Example(features = tf.train.Features(feature = {



  • "label":_int64_feature(index),



  • "img_raw":_bytes_feature(img_raw),



  • }))



  • writer.write(example.SerializeToString()) #序列化为字符串



  • writer.close()







  • def read_and_decode(filename, batch_size): # read train.tfrecords



  • filename_queue = tf.train.string_input_producer([filename])# create a queue







  • reader = tf.TFRecordReader()



  • _, serialized_example = reader.read(filename_queue)#return file_name and file



  • features = tf.parse_single_example(serialized_example,



  • features={



  • 'label': tf.FixedLenFeature([], tf.int64),



  • 'img_raw' :tf.FixedLenFeature([],tf.string),



  • })#return image and label







  • img = tf.decode_raw(features['img_raw'], tf.uint8)



  • img = tf.reshape(img, [208, 208, 3]) #reshape image to 512*80*3







  • label = tf.cast(features['label'], tf.int32) #throw label tensor







  • img_batch, label_batch = tf.train.shuffle_batch([img, label],



  • batch_size= batch_size,



  • num_threads=64,



  • capacity=2000,



  • min_after_dequeue=1500,



  • )



  • return img_batch, tf.reshape(label_batch,[batch_size])











  • tfrecords_file = 'D:/Anaconda3/tensorflow/dataset/train.tfrecords'



  • Batch_size = 6



  • image_batch, label_batch = read_and_decode(tfrecords_file, Batch_size)







  • with tf.Session() as sess:







  • i = 0



  • coord = tf.train.Coordinator()



  • threads = tf.train.start_queue_runners(coord=coord)







  • try:



  • while not coord.should_stop() and i<1:



  • # just plot one batch size



  • image, label = sess.run([image_batch, label_batch])



  • for j in np.arange(BATCH_SIZE):



  • print('label: %d' % label[j])



  • plt.imshow(image[j,:,:,:])



  • plt.show()



  • i+=1



  • except tf.errors.OutOfRangeError:



  • print('done!')



  • finally:



  • coord.request_stop()



  • coord.join(threads)




---------------------作者:千里江河 来源:CSDN 原文:blog.csdn.net/qinghange/ … 版权声明:本文为博主原创文章,转载请附上博文链接!

更多免费技术资料可关注:annalin1203