tf.keras.datasets

79 阅读5分钟

boston_housing

样本包含 20 世纪 70 年代末波士顿郊区不同地点的房屋的 13 个属性。目标是某个地点房屋的中位值(以千美元为单位)。

tf.keras.datasets.boston_housing.load_data(
    path='boston_housing.npz', test_split=0.2, seed=113
)

x_train, x_test:具有包含训练样本(对于 x_train)或测试样本(对于 y_train)的形状的 numpy 数组(num_samples, 13) 。

y_train、y_test:(num_samples,)包含目标标量的 numpy 形状数组。目标是通常在 10 到 50 之间的浮动标量,代表房价(以千美元为单位)。

参数:

path本地缓存数据集的路径。相对于 ~/.keras/datasets
test_split保留作为测试集的数据的一部分。
seed用于在计算测试分割之前打乱数据的随机种子。

数据查看

import tensorflow as tf
from tensorflow.keras import datasets

(x_train, y_train), (x_test, y_test) = datasets.boston_housing.load_data(path='boston_housing.npz', test_split=0.2, seed=113)

x_train.shape, y_train.shape, x_test.shape, y_test.shape

x_train

y_train

cifar10

这是一个包含 50,000 张 32x32 彩色训练图像和 10,000 张测试图像的数据集,标记了 10 多个类别。

tf.keras.datasets.cifar10.load_data()

类别:

LabelDescription
0airplane
1automobile
2bird
3cat
4deer
5dog
6frog
7horse
8ship
9truck
  • x_train:uint8 NumPy 灰度图像数据数组 (50000, 32, 32, 3),包含训练数据。像素值范围从 0 到 255。
  • y_train:uint8 NumPy 标签数组(0-9 范围内的整数),具有(50000, 1)训练数据的形状。
  • x_test:uint8 NumPy 灰度图像数据数组 (10000, 32, 32, 3),包含测试数据。像素值范围从 0 到 255。
  • y_test:uint8 NumPy 标签数组(0-9 范围内的整数),具有(10000, 1)测试数据的形状。

数据查看

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train.shape,x_test.shape,y_train.shape,y_test.shape

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']
plt.figure(figsize=(20,5))
for i in range(20):
    plt.subplot(2,10,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i])
    plt.xlabel(class_names[y_train[i,0]])
plt.show()

cifar100

这是一个包含 50,000 张 32x32 颜色训练图像和 10,000 张测试图像的数据集,标记了 100 多个细粒度类,这些细粒度类又分为 20 个粗粒度类。

tf.keras.datasets.cifar100.load_data(
    label_mode='fine'
)
label_mode“fine”、“coarse”之一。如果是“fine”,则类别标签是细粒度标签,如果是“coarse”,则输出标签是粗粒度超类。

数据查看

(x_train, y_train), (x_test, y_test) = datasets.cifar100.load_data()
x_train.shape,x_test.shape,y_train.shape,y_test.shape

plt.figure(figsize=(20,5))
for i in range(20):
    plt.subplot(2,10,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i])
    plt.xlabel(y_train[i,0])
plt.show()

fashion_mnist

这是一个包含 10 个时尚类别的 60,000 张 28x28 灰度图像的数据集,以及一个包含 10,000 张图像的测试集。该数据集可用作 MNIST 的直接替代品。

LabelDescription
0T-shirt/top
1Trouser
2Pullover
3Dress
4Coat
5Sandal
6Shirt
7Sneaker
8Bag
9Ankle boot

数据查看

(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()
x_train.shape,x_test.shape,y_train.shape,y_test.shape

plt.figure(figsize=(20,5))
for i in range(20):
    plt.subplot(2,10,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i])
    plt.xlabel(y_train[i])
plt.show()

imdb

这是来自 IMDB 的 25,000 条电影评论的数据集,按情绪(正面/负面)标记。评论已经过预处理,每个评论都被编码为单词索引(整数)列表。为了方便起见,单词按数据集中的总体频率进行索引,例如整数“3”编码数据中第三个最常见的单词。这允许快速过滤操作,例如:“仅考虑前 10,000 个最常见的单词,但消除前 20 个最常见的单词”。

按照惯例,“0”并不代表特定的单词,而是用于对填充令牌进行编码。

tf.keras.datasets.imdb.load_data(
    path='imdb.npz',
    num_words=None,
    skip_top=0,
    maxlen=None,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
    **kwargs
)
path缓存数据的位置。相对于 ~/.keras/dataset
num_wordsint 或None。单词按照它们出现的频率(在训练集中)进行排名,并且只保留最频繁的单词。任何频率较低的单词都将被oov_char 代替。如果时None,则保留所有单词。
skip_top跳过前 N 个最常出现的单词(这可能没有提供信息)。这些单词将在数据集中显示为oov_char 。当为 0 时,不跳过任何单词。
maxlenint 或None。最大序列长度。任何更长的序列将被截断。None表示不截断。
seed随机种子。
start_char序列的开始将用此字符标记。0 通常是填充字符。默认为1.
oov_char词汇外的字符。由于num_wordsskip_top限制而被删除的单词将被替换为该字符。
index_from使用此索引或更高索引对实际单词进行索引。

get_word_index

获取将单词映射到 IMDB 数据集中索引的字典。

tf.keras.datasets.imdb.get_word_index(
    path='imdb_word_index.json'
)
word_index = datasets.imdb.get_word_index()
type(word_index)

word_index

数据查看

(x_train, y_train), (x_test, y_test) = datasets.imdb.load_data()
x_train.shape,x_test.shape,y_train.shape,y_test.shape

len(x_train),len(x_train[0]),len(x_train[1]),len(x_train[2])

print(y_train)

mnist

这是一个包含 60,000 张 10 位数字的 28x28 灰度图像的数据集,以及一个包含 10,000 张图像的测试集。

(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train.shape,x_test.shape,y_train.shape,y_test.shape

plt.figure(figsize=(20,5))
for i in range(20):
    plt.subplot(2,10,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i],cmap=plt.cm.binary)
    plt.xlabel(y_train[i])
plt.show()

reuters

这是来自路透社的 11,228 条新闻专线的数据集,标记了超过 46 个主题。

每个新闻专线都被编码为单词索引(整数)列表。为了方便起见,单词按数据集中的总体频率进行索引,例如整数“3”编码数据中第三个最常见的单词。这允许快速过滤操作,例如:“仅考虑前 10,000 个最常见的单词,但消除前 20 个最常见的单词”。

按照惯例,“0”并不代表特定的单词,而是用于编码任何未知的单词。

tf.keras.datasets.reuters.load_data(
    path='reuters.npz',
    num_words=None,
    skip_top=0,
    maxlen=None,
    test_split=0.2,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
    **kwargs
)

参数和imdb数据集差不多。

get_word_index

和imdb数据集的一样。

tf.keras.datasets.reuters.get_word_index(
    path='reuters_word_index.json'
)

实际的单词索引从 3 开始,保留 3 个索引:0(填充)、1(开始)、2(oov)。

例如,“the”的单词索引为 1,但在实际训练数据中,“the”的索引将为 1 + 3 = 4。反之亦然,要使用此映射将训练数据中的单词索引转换回单词,索引需要减去3。

get_label_names

数据查看