开启掘金成长之旅!这是我参与「掘金日新计划 · 12 月更文挑战」的第九天,点击查看活动详情
总结:此文为12月更文计划第九天第十四篇。
dataset的学习
今天开始dataset的学习:
内存中构建dataset,初始化dataset,返回的像一个迭代器
dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))
print(dataset)
调用Dataset.from_tensor_slices接口可以将数据转化为tensor类型。
可以做遍历,每一个都是tensor
for item in dataset:
print(item)
输出的结果如下:可以看到每一个都是Tensor类型。
为了epoch服务的,repeat(3) 表示
dataset1 = dataset.repeat(3)
for item in dataset1:
print(item)
表示可以将数据进行3次循环,可以在做训练集的时候使用:
对于数据,可以将数据的形状进行划分,调用batch方法:
dataset = dataset.repeat(3).batch(7)
print(dataset)
i=0
for item in dataset:
i=i+1
print(item)
输出的结果如下:
此刻可以看到数据被划分为的形状:
for i in dataset:
print(i)
输出的结果如下:
以上就是对数据集在训练过程中进行的一些简单操作,可以在训练数据集的过程中按照以上的方法使用。
dataset.interleave
dataset.interleave的接口。 interleave()是Dataset的类方法,所以interleave是作用在一个Dataset上的。
首先该方法会从该Dataset中取出cycle_length个element,然后对这些element apply map_func, 得到cycle_length个新的Dataset对象。然后从这些新生成的Dataset对象中取数据,每个Dataset对象一次取block_length个数据。当新生成的某个Dataset的对象取尽时,从原Dataset中再取一个element,然后apply map_func,以此类推。
dataset2 = dataset.interleave(
lambda v: tf.data.Dataset.from_tensor_slices(v), # map_fn,第一参数是回调函数
cycle_length = 5, # cycle_length,每一个cycle提取的个数
block_length = 6, # block_length
)
print(dataset2)
print('-'*50)
i=0
for item in dataset2:
i=i+1
print(item)
输出的结果如下:
这是另一个例子:
cycle_length是使用了几个block以后,就要重复的 参数
a = tf.data.Dataset.range(1,6)
# b=a.repeat(6)
# for i in b:
# print(i)
print('-'*50)
a1=a.interleave(lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
cycle_length=3, block_length=4)
for i in a1:
print(i)
输出的结果如下: