tensorflow中的datasets

134 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

tensorflow2.0版本之后相比于1.0版本更方便的一个地方是引入了tensorflow数据集,它可以以tf.data和numpy的形式加载进来。datasets的产生,使得大家对于解决某些数据集的繁琐工作量大大减少,通过td.data API可以构建更高效的输入。

对于datasets的安装方法如下:

pip install tensorflow-datasets

安装时大概如下面的图所示。

1.PNG

安装完成后下一步就是直接上手测试,首先查看一下它包含了哪些数据集,下次需要时可以直接导入。

import tensorflow as tf
import tensorflow_datasets as tfdt

# 查看包含的数据集
data_list = tfdt.list_builders()
print(data_list)
print(len(data_list))

结果如下

2.PNG

从打印的结果来看,包含的数据集的数量是真的不少,有310种,这对于学习来说足够用了。就拿上一篇博客提到的MNIST数据集来说,下面我们查看一下MNIST数据集的情况。

import tensorflow as tf
import tensorflow_datasets as tfdt

# 查看MNIST数据集信息

Mnist, info = tfdt.load('mnist', with_info=True)
train, test = Mnist['train'], Mnist['test']
print(info)

执行上述代码,我们首先看不到不一定是直接打印的MNIST数据集信息,很有可能是正在下载数据集的信息,如下所示。

3.PNG

只有当数据集加载完成后我们才会看到数据集相关的信息。

4.PNG

信息中包含了数据集的名称,整个数据集的大小,数据集中图片的尺寸、标签的格式,类别等等非常详细的信息,对于我们快速了解数据集非常有用,可以帮助我们快速的使用数据集。

另外介绍几个函数,用于从numpy数组中创建数据集。

(1) from_tensor_slices(),用于接收一个或者多个numpy张量,支持批处理操作。

(2)from_tensor,与(1)类似,不同是不支持批处理操作

(3)from_generator(),从生成器获取输入。

既然可以创建,那必然能实现相应的操作,例如数据的变换等。

(1)batch(), 设置划分数据集的大小

(2),shuffle(),随机对数据集执行打乱操作

(3)map(),用函数对数据做映射变换

(4)filter(),用函数对数据做映射变换。

(5)repeat(),复制数据

下面的博客中将一一的使用这些函数,到时候再做更详细的解析。