本文已参与「新人创作礼」活动,一起开启掘金创作之路。
tensorflow2.0版本之后相比于1.0版本更方便的一个地方是引入了tensorflow数据集,它可以以tf.data和numpy的形式加载进来。datasets的产生,使得大家对于解决某些数据集的繁琐工作量大大减少,通过td.data API可以构建更高效的输入。
对于datasets的安装方法如下:
pip install tensorflow-datasets
安装时大概如下面的图所示。
安装完成后下一步就是直接上手测试,首先查看一下它包含了哪些数据集,下次需要时可以直接导入。
import tensorflow as tf
import tensorflow_datasets as tfdt
# 查看包含的数据集
data_list = tfdt.list_builders()
print(data_list)
print(len(data_list))
结果如下
从打印的结果来看,包含的数据集的数量是真的不少,有310种,这对于学习来说足够用了。就拿上一篇博客提到的MNIST数据集来说,下面我们查看一下MNIST数据集的情况。
import tensorflow as tf
import tensorflow_datasets as tfdt
# 查看MNIST数据集信息
Mnist, info = tfdt.load('mnist', with_info=True)
train, test = Mnist['train'], Mnist['test']
print(info)
执行上述代码,我们首先看不到不一定是直接打印的MNIST数据集信息,很有可能是正在下载数据集的信息,如下所示。
只有当数据集加载完成后我们才会看到数据集相关的信息。
信息中包含了数据集的名称,整个数据集的大小,数据集中图片的尺寸、标签的格式,类别等等非常详细的信息,对于我们快速了解数据集非常有用,可以帮助我们快速的使用数据集。
另外介绍几个函数,用于从numpy数组中创建数据集。
(1) from_tensor_slices(),用于接收一个或者多个numpy张量,支持批处理操作。
(2)from_tensor,与(1)类似,不同是不支持批处理操作
(3)from_generator(),从生成器获取输入。
既然可以创建,那必然能实现相应的操作,例如数据的变换等。
(1)batch(), 设置划分数据集的大小
(2),shuffle(),随机对数据集执行打乱操作
(3)map(),用函数对数据做映射变换
(4)filter(),用函数对数据做映射变换。
(5)repeat(),复制数据
下面的博客中将一一的使用这些函数,到时候再做更详细的解析。