Using TensorFlow Dataset (Part 2)



Basic usage of the tfrecord API

tfrecord is a file format.

Each record stored in it is a tf.train.Example.

A tf.train.Example holds a tf.train.Features, which maps {"key": tf.train.Feature}.

A tf.train.Feature comes in several types, including tf.train.BytesList (for byte strings), tf.train.FloatList (for floats), and tf.train.Int64List (for integers).

Define a Features object:

favorite_books = [name.encode('utf-8')
                  for name in ["machine learning", "cc150"]]
favorite_books_bytelist = tf.train.BytesList(value = favorite_books)
print(favorite_books_bytelist)

hours_floatlist = tf.train.FloatList(value = [15.5, 9.5, 7.0, 8.0])
print(hours_floatlist)

age_int64list = tf.train.Int64List(value = [42])
print(age_int64list)

features = tf.train.Features(
    feature = {
        "favorite_books": tf.train.Feature(
            bytes_list = favorite_books_bytelist),
        "hours": tf.train.Feature(
            float_list = hours_floatlist),
        "age": tf.train.Feature(int64_list = age_int64list),
    }
)
print(features)

Output:

value: "machine learning"
value: "cc150"

value: 15.5
value: 9.5
value: 7.0
value: 8.0

value: 42

feature {
  key: "age"
  value {
    int64_list {
      value: 42
    }
  }
}
feature {
  key: "favorite_books"
  value {
    bytes_list {
      value: "machine learning"
      value: "cc150"
    }
  }
}
feature {
  key: "hours"
  value {
    float_list {
      value: 15.5
      value: 9.5
      value: 7.0
      value: 8.0
    }
  }
}

Define an Example and serialize it:

example = tf.train.Example(features=features)
print(example)

serialized_example = example.SerializeToString()
print(serialized_example)

Output:

features {
  feature {
    key: "age"
    value {
      int64_list {
        value: 42
      }
    }
  }
  feature {
    key: "favorite_books"
    value {
      bytes_list {
        value: "machine learning"
        value: "cc150"
      }
    }
  }
  feature {
    key: "hours"
    value {
      float_list {
        value: 15.5
        value: 9.5
        value: 7.0
        value: 8.0
      }
    }
  }
}

b'\n\\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*\n-\n\x0efavorite_books\x12\x1b\n\x19\n\x10machine learning\n\x05cc150\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10\x00\x00xA\x00\x00\x18A\x00\x00\xe0@\x00\x00\x00A'

Write the Example to disk, producing an actual tfrecord file:

output_dir = 'tfrecord_basic'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
filename = "test.tfrecords"
filename_fullpath = os.path.join(output_dir, filename)
with tf.io.TFRecordWriter(filename_fullpath) as writer:
    for i in range(3):
        writer.write(serialized_example)

Read the tfrecord file back:

dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    print(serialized_example_tensor)

Parse the serialized examples:

expected_features = {
    "favorite_books": tf.io.VarLenFeature(dtype = tf.string),  # 变长的字符串
    "hours": tf.io.VarLenFeature(dtype = tf.float32),   # 变长的数据
    "age": tf.io.FixedLenFeature([], dtype = tf.int64),  # 定长的数据
}
dataset = tf.data.TFRecordDataset([filename_fullpath])
# parse each serialized record
for serialized_example_tensor in dataset:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        expected_features)
    books = tf.sparse.to_dense(example["favorite_books"],
                               default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))

Output:

machine learning
cc150
machine learning
cc150
machine learning
cc150
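
The records can also be batched first and parsed in one call with tf.io.parse_example; a minimal sketch (not from the original article), reusing expected_features from above:

# Sketch: batch the serialized records first, then parse the whole batch at once
dataset = tf.data.TFRecordDataset([filename_fullpath]).batch(2)
for serialized_batch in dataset:
    examples = tf.io.parse_example(serialized_batch, expected_features)
    books = tf.sparse.to_dense(examples["favorite_books"], default_value=b"")
    print(books)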

Besides plain files, tfrecord files can also be written with compression:

filename_fullpath_zip = filename_fullpath + '.zip'
options = tf.io.TFRecordOptions(compression_type = "GZIP")
with tf.io.TFRecordWriter(filename_fullpath_zip, options) as writer:
    for i in range(3):
        writer.write(serialized_example)

The uncompressed file is 330 B.

The compressed file is 127 B.
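
These sizes can be checked directly; a minimal sketch (not from the original article) using os.path.getsize:

# Quick check of the on-disk sizes of the two files written above
print(os.path.getsize(filename_fullpath))       # uncompressed, about 330 B
print(os.path.getsize(filename_fullpath_zip))   # GZIP-compressed, about 127 B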

To read the compressed file, just pass compression_type = "GZIP" when creating the dataset:

dataset_zip = tf.data.TFRecordDataset([filename_fullpath_zip], 
                                      compression_type= "GZIP")
for serialized_example_tensor in dataset_zip:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        expected_features)
    books = tf.sparse.to_dense(example["favorite_books"],
                               default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))

Output:

machine learning
cc150
machine learning
cc150
machine learning
cc150

Generating tfrecords files

Using the csv files generated earlier, collect their filenames:

source_dir = "./generate_csv/"

def get_filenames_by_prefix(source_dir, prefix_name):
    all_files = os.listdir(source_dir)
    results = []
    for filename in all_files:
        if filename.startswith(prefix_name):
            results.append(os.path.join(source_dir, filename))
    return results

train_filenames = get_filenames_by_prefix(source_dir, "train")
valid_filenames = get_filenames_by_prefix(source_dir, "valid")
test_filenames = get_filenames_by_prefix(source_dir, "test")

import pprint
pprint.pprint(train_filenames)
pprint.pprint(valid_filenames)
pprint.pprint(test_filenames)

Output:

['./generate_csv/train_01.csv', './generate_csv/train_02.csv', './generate_csv/train_07.csv', './generate_csv/train_08.csv', './generate_csv/train_11.csv', './generate_csv/train_15.csv', './generate_csv/train_04.csv', './generate_csv/train_14.csv', './generate_csv/train_05.csv', './generate_csv/train_09.csv', './generate_csv/train_18.csv', './generate_csv/train_06.csv', './generate_csv/train_03.csv', './generate_csv/train_16.csv', './generate_csv/train_13.csv', './generate_csv/train_19.csv', './generate_csv/train_00.csv', './generate_csv/train_12.csv', './generate_csv/train_10.csv', './generate_csv/train_17.csv']
['./generate_csv/valid_02.csv', './generate_csv/valid_08.csv', './generate_csv/valid_00.csv', './generate_csv/valid_04.csv', './generate_csv/valid_07.csv', './generate_csv/valid_01.csv', './generate_csv/valid_09.csv', './generate_csv/valid_06.csv', './generate_csv/valid_05.csv', './generate_csv/valid_03.csv']
['./generate_csv/test_07.csv', './generate_csv/test_08.csv', './generate_csv/test_03.csv', './generate_csv/test_04.csv', './generate_csv/test_00.csv', './generate_csv/test_06.csv', './generate_csv/test_05.csv', './generate_csv/test_01.csv', './generate_csv/test_02.csv', './generate_csv/test_09.csv']

Read the csv files (see Using TensorFlow Dataset (Part 1) for the details):

def parse_csv_line(line, n_fields = 9):
    defs = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parsed_fields[0:-1])
    y = tf.stack(parsed_fields[-1:])
    return x, y

def csv_reader_dataset(filenames, n_readers=5,
                       batch_size=32, n_parse_threads=5,
                       shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length = n_readers
    )
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

batch_size = 32
train_set = csv_reader_dataset(train_filenames,
                               batch_size = batch_size)
valid_set = csv_reader_dataset(valid_filenames,
                               batch_size = batch_size)
test_set = csv_reader_dataset(test_filenames,
                              batch_size = batch_size)

Convert the samples obtained from the dataset into Examples and serialize them:

def serialize_example(x, y):
    """Converts x, y to tf.train.Example and serialize"""
    input_features = tf.train.FloatList(value = x)
    label = tf.train.FloatList(value = y)
    features = tf.train.Features(
        feature = {
            "input_features": tf.train.Feature(
                float_list = input_features),
            "label": tf.train.Feature(float_list = label)
        }
    )
    example = tf.train.Example(features = features)
    return example.SerializeToString()
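
As a quick sanity check (a minimal sketch, not from the original article, using the train_set defined above), one sample can be serialized and parsed straight back:

# Serialize the first sample of one batch and parse it back with the matching feature spec
for x_batch, y_batch in train_set.take(1):
    serialized = serialize_example(x_batch[0], y_batch[0])
    parsed = tf.io.parse_single_example(serialized, {
        "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
        "label": tf.io.FixedLenFeature([1], dtype=tf.float32),
    })
    print(parsed["input_features"], parsed["label"])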

Then write the serialized data to files:

def csv_dataset_to_tfrecords(base_filename, dataset,
                             n_shards, steps_per_shard,
                             compression_type = None):
    options = tf.io.TFRecordOptions(
        compression_type = compression_type)
    all_filenames = []
    
    for shard_id in range(n_shards):
        filename_fullpath = '{}_{:05d}-of-{:05d}'.format(
            base_filename, shard_id, n_shards)
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard):
                for x_example, y_example in zip(x_batch, y_batch):
                    writer.write(
                        serialize_example(x_example, y_example))
        all_filenames.append(filename_fullpath)
    return all_filenames
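
Note that dataset.skip(shard_id * steps_per_shard).take(steps_per_shard) re-reads the dataset from the start for every shard. As an alternative (a hedged sketch, not the article's approach), the shards can be written in a single pass by keeping one writer per shard:

# Single-pass variant (sketch): open all shard writers up front and route each
# batch to its shard instead of calling skip()/take() per shard.
def csv_dataset_to_tfrecords_single_pass(base_filename, dataset,
                                         n_shards, steps_per_shard,
                                         compression_type = None):
    options = tf.io.TFRecordOptions(compression_type = compression_type)
    all_filenames = ['{}_{:05d}-of-{:05d}'.format(base_filename, i, n_shards)
                     for i in range(n_shards)]
    writers = [tf.io.TFRecordWriter(f, options) for f in all_filenames]
    for step, (x_batch, y_batch) in enumerate(
            dataset.take(n_shards * steps_per_shard)):
        writer = writers[step // steps_per_shard]  # shard this batch belongs to
        for x_example, y_example in zip(x_batch, y_batch):
            writer.write(serialize_example(x_example, y_example))
    for writer in writers:
        writer.close()
    return all_filenames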

Call the function just defined to generate the tfrecords files:

n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecords"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard, None)
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard, None)
test_tfrecord_filenames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard, None)
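
With batch_size = 32 and n_shards = 20, train_steps_per_shard works out to 11610 // 32 // 20 = 18 batches, i.e. roughly 18 × 32 = 576 examples per training shard.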

Compressed files can be generated the same way; just change the output directory and pass compression_type = "GZIP":

n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecords_zip"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard,
    compression_type = "GZIP")
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard,
    compression_type = "GZIP")
test_tfrecord_filenames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard,
    compression_type = "GZIP")

Each uncompressed test file is 23.8 KB.

Each compressed test file is 10.2 KB.

Reading tfrecord files and using them with tf.keras

Read the tfrecord files:

expected_features = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32)
}

# parse a single serialized example
def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example,
                                         expected_features)
    return example["input_features"], example["label"]

# build a dataset from the list of tfrecord filenames
def tfrecords_reader_dataset(filenames, n_readers=5,
                             batch_size=32, n_parse_threads=5,
                             shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(
            filename, compression_type = "GZIP"),
        cycle_length = n_readers
    )
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_example,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset
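
Note that compression_type = "GZIP" is hardcoded in the interleave step, so this reader expects the compressed shards written to generate_tfrecords_zip; to read the uncompressed shards, drop that argument (or pass compression_type = None).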

tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames,
                                           batch_size = 3)
for x_batch, y_batch in tfrecords_train.take(1):
    print(x_batch)
    print(y_batch)

Output:

tf.Tensor(
[[ 0.27419487 -1.0895426   0.27506563 -0.1635966   0.40234062 -0.03872543
   1.5657866  -0.7864974 ]
 [-0.7880703  -1.8104447  -0.28291935  0.11018982  0.9240302  -0.13634782
  -0.5291081   1.1160139 ]
 [-0.17967628  0.6726626  -0.60395646 -0.08978043 -0.4759124   0.05870318
   0.95924693 -1.4406418 ]], shape=(3, 8), dtype=float32)
tf.Tensor(
[[2.206]
 [1.031]
 [2.386]], shape=(3, 1), dtype=float32)

Train a model on the data:

batch_size = 32
tfrecords_train_set = tfrecords_reader_dataset(
    train_tfrecord_filenames, batch_size = batch_size)
tfrecords_valid_set = tfrecords_reader_dataset(
    valid_tfrecord_filenames, batch_size = batch_size)
tfrecords_test_set = tfrecords_reader_dataset(
    test_tfrecord_filenames, batch_size = batch_size)
model = keras.models.Sequential([
    keras.layers.Dense(30, activation='relu',
                       input_shape=[8]),
    keras.layers.Dense(1),
])
model.compile(loss="mean_squared_error", optimizer="sgd")
callbacks = [keras.callbacks.EarlyStopping(
    patience=5, min_delta=1e-2)]

history = model.fit(tfrecords_train_set,
                    validation_data = tfrecords_valid_set,
                    steps_per_epoch = 11160 // batch_size,
                    validation_steps = 3870 // batch_size,
                    epochs = 100,
                    callbacks = callbacks)

Output:

Epoch 1/100
348/348 [==============================] - 1s 2ms/step - loss: 0.9537 - val_loss: 0.5998
Epoch 2/100
348/348 [==============================] - 1s 2ms/step - loss: 0.4812 - val_loss: 0.5260
Epoch 3/100
348/348 [==============================] - 1s 2ms/step - loss: 0.5582 - val_loss: 0.5030
Epoch 4/100
348/348 [==============================] - 1s 2ms/step - loss: 0.4256 - val_loss: 0.4746
Epoch 5/100
348/348 [==============================] - 1s 2ms/step - loss: 0.4052 - val_loss: 0.4668
Epoch 6/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3970 - val_loss: 0.4504
Epoch 7/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3889 - val_loss: 0.4394
Epoch 8/100
348/348 [==============================] - 1s 2ms/step - loss: 0.4483 - val_loss: 0.4602
Epoch 9/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3961 - val_loss: 0.4377
Epoch 10/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3725 - val_loss: 0.4281
Epoch 11/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3710 - val_loss: 0.4226
Epoch 12/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3712 - val_loss: 0.4149
Epoch 13/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3520 - val_loss: 0.4102
Epoch 14/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3626 - val_loss: 0.4126
Epoch 15/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3446 - val_loss: 0.4111
Epoch 16/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3628 - val_loss: 0.4012
Epoch 17/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3992 - val_loss: 0.4139
Epoch 18/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3575 - val_loss: 0.4093
Epoch 19/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3651 - val_loss: 0.4052
Epoch 20/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3486 - val_loss: 0.3981
Epoch 21/100
348/348 [==============================] - 1s 2ms/step - loss: 0.3662 - val_loss: 0.4002

Finally, evaluate the model on the test set:

model.evaluate(tfrecords_test_set, steps = 5160 // batch_size)

Output:

161/161 [==============================] - 0s 1ms/step - loss: 0.3830
0.38302409648895264
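
To see actual predictions (a minimal sketch, not from the original article, using the names defined above), take one batch from the test set and compare the model output with the labels:

# Predict on one test batch and print the first few predictions and labels
for x_batch, y_batch in tfrecords_test_set.take(1):
    y_pred = model.predict(x_batch)
    print(y_pred[:5])     # predicted house values
    print(y_batch[:5])    # corresponding labels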