Generating TFRecord-format data with TensorFlow

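TFRecord is TensorFlow's binary record format. Samples are serialized (typically as tf.train.Example protos) and stored one after another in a single file, which lets the input pipeline read data efficiently and improves I/O performance during training.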
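The process has three steps: steps 1 and 2 define helper functions, and step 3 uses them to generate the TFRecord data.

1. TF-Feature: wrap the data (values) in a tf.train.Feature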

import tensorflow as tf


def int64_feature(values):
  """Returns a TF-Feature of int64s.
  Args:
    values: A scalar or list of values.
  Returns:
    A TF-Feature.
  """
  if not isinstance(values, (tuple, list)):
    values = [values]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def bytes_feature(values):
  """Returns a TF-Feature of bytes.
  Args:
    values: A bytes object (e.g. raw encoded image data).
  Returns:
    A TF-Feature.
  """
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def float_feature(values):
  """Returns a TF-Feature of floats.
  Args:
    values: A scalar or list of values (stored as 32-bit floats).
  Returns:
    A TF-Feature.
  """
  if not isinstance(values, (tuple, list)):
    values = [values]
  return tf.train.Feature(float_list=tf.train.FloatList(value=values))

2. tf.train.Example: combine the features into a tf.train.Example

def create_tf_example(features_dict):
    '''
    :param features_dict: dict of the form {'img_query_bytes': bytes}
    :return: tf.train.Example
    '''
    feature_map = {
        'img_query': bytes_feature(features_dict['img_query_bytes']),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

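3. Write the TFRecord file: use tf.python_io.TFRecordWriter(out_path) to write each tf.train.Example (in TensorFlow 2.x the equivalents are tf.io.TFRecordWriter and tf.io.gfile.GFile)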

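    out_path = './dataset/train.record'
    # img_paths: hypothetical list of image file paths, assumed to be defined earlier
    iter_steps = len(img_paths)
    with tf.python_io.TFRecordWriter(out_path) as writer:
        for iter_num, img_path in enumerate(img_paths):
            features_dict = dict()
            # read the raw (still encoded) image bytes
            with tf.gfile.GFile(img_path, 'rb') as fid:
                features_dict['img_query_bytes'] = fid.read()

            example = create_tf_example(features_dict)
            # serialize the Example proto and append it to the file as one record
            writer.write(example.SerializeToString())
            if iter_num % 1000 == 0:
                print('done: {} / {}'.format(iter_num, iter_steps))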
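To see what the writer puts on disk, here is a stdlib-only sketch of the TFRecord record framing, assuming an uncompressed file. Each record is stored as an 8-byte little-endian length, a 4-byte masked CRC-32C of that length, the payload (here, the bytes from example.SerializeToString()), and a 4-byte masked CRC-32C of the payload. The names crc32c, masked_crc, frame_record and iter_records are hypothetical helpers for illustration, not TensorFlow API.

```python
import struct

# CRC-32C (Castagnoli), reflected polynomial 0x82F63B78. zlib.crc32 uses a
# different polynomial (CRC-32/IEEE), so it cannot be substituted here.
_CRC32C_TABLE = []
for _i in range(256):
    _c = _i
    for _ in range(8):
        _c = (_c >> 1) ^ 0x82F63B78 if _c & 1 else _c >> 1
    _CRC32C_TABLE.append(_c)


def crc32c(data: bytes) -> int:
    """Table-driven CRC-32C of data (hypothetical helper)."""
    crc = 0xFFFFFFFF
    for b in data:
        crc = _CRC32C_TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord stores a masked CRC: rotate right by 15 bits, then add a constant."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def frame_record(data: bytes) -> bytes:
    """Lay out one payload as a TFRecord record (hypothetical helper)."""
    length = struct.pack('<Q', len(data))
    return (length
            + struct.pack('<I', masked_crc(length))
            + data
            + struct.pack('<I', masked_crc(data)))


def iter_records(buf: bytes):
    """Yield each payload in an uncompressed TFRecord buffer, checking both CRCs."""
    pos = 0
    while pos < len(buf):
        if pos + 12 > len(buf):
            raise ValueError('truncated header at offset %d' % pos)
        header = buf[pos:pos + 8]
        (length,) = struct.unpack('<Q', header)
        (len_crc,) = struct.unpack('<I', buf[pos + 8:pos + 12])
        if len_crc != masked_crc(header):
            raise ValueError('corrupted length at offset %d' % pos)
        start = pos + 12
        end = start + length + 4
        if end > len(buf):
            raise ValueError('truncated record at offset %d' % pos)
        data = buf[start:start + length]
        (data_crc,) = struct.unpack('<I', buf[start + length:end])
        if data_crc != masked_crc(data):
            raise ValueError('corrupted data at offset %d' % pos)
        yield data
        pos = end
```

Under the same assumptions, reading the bytes of ./dataset/train.record and passing them to iter_records should yield one serialized Example per image, and each payload can be parsed back with tf.train.Example.FromString(payload). The length header is what lets a reader step from record to record without parsing the protos, which is part of why sequential reads are cheap.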

Structure of an Example

In the sketch below, dict denotes a dictionary that maps feature names to tf.train.Feature objects.

tf.train.Example {
    features: tf.train.Features{
        feature: dict{
            'feature_name':tf.train.Feature{   # exactly one of the three lists below is set
                int64_list:tf.train.Int64List{value:list}
                bytes_list:tf.train.BytesList{value:list}
                float_list:tf.train.FloatList{value:list}
            }
        }
    }
}