手动训练模型(dataset)

113 阅读2分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 12 月更文挑战」的第九天,点击查看活动详情

总结:此文为12月更文计划第九天第十四篇。

dataset的学习

今天开始dataset的学习:

内存中构建dataset,初始化dataset,返回的像一个迭代器

dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))
print(dataset)

调用Dataset.from_tensor_slices接口可以将数据转化为tensor类型。

可以做遍历,每一个都是tensor

for item in dataset:
    print(item)

输出的结果如下:可以看到每一个都是Tensor类型。 image.png

为了epoch服务的,repeat(3) 表示

dataset1 = dataset.repeat(3) 
for item in dataset1:
    print(item)

表示可以将数据进行3次循环,可以在做训练集的时候使用:

image.png

对于数据,可以将数据的形状进行划分,调用batch方法:

dataset = dataset.repeat(3).batch(7)
print(dataset)
i=0
for item in dataset:
    i=i+1
    print(item)

输出的结果如下:

image.png

此刻可以看到数据被划分为的形状:

for i in dataset:
    print(i)

输出的结果如下:

image.png

以上就是对数据集在训练过程中进行的一些简单操作,可以在训练数据集的过程中按照以上的方法使用。

dataset.interleave

dataset.interleave的接口。 interleave()是Dataset的类方法,所以interleave是作用在一个Dataset上的。

首先该方法会从该Dataset中取出cycle_length个element,然后对这些element apply map_func, 得到cycle_length个新的Dataset对象。然后从这些新生成的Dataset对象中取数据,每个Dataset对象一次取block_length个数据。当新生成的某个Dataset的对象取尽时,从原Dataset中再取一个element,然后apply map_func,以此类推。

dataset2 = dataset.interleave(
    lambda v: tf.data.Dataset.from_tensor_slices(v), # map_fn,第一参数是回调函数
    cycle_length = 5, # cycle_length,每一个cycle提取的个数
    block_length = 6, # block_length
)
print(dataset2)
print('-'*50)
i=0
for item in dataset2:
    i=i+1
    print(item)

输出的结果如下:

image.png

这是另一个例子:

cycle_length是使用了几个block以后,就要重复的 参数

a = tf.data.Dataset.range(1,6) 
# b=a.repeat(6)
# for i in b:
#     print(i)
print('-'*50)
a1=a.interleave(lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
            cycle_length=3, block_length=4)
for i in a1:
    print(i)

输出的结果如下:

image.png