boston_housing
样本包含 20 世纪 70 年代末波士顿郊区不同地点的房屋的 13 个属性。目标是某个地点房屋的中位值(以千美元为单位)。
tf.keras.datasets.boston_housing.load_data(
path='boston_housing.npz', test_split=0.2, seed=113
)
x_train, x_test:具有包含训练样本(对于 x_train)或测试样本(对于 y_train)的形状的 numpy 数组(num_samples, 13) 。
y_train、y_test:(num_samples,)包含目标标量的 numpy 形状数组。目标是通常在 10 到 50 之间的浮动标量,代表房价(以千美元为单位)。
参数:
path | 本地缓存数据集的路径。相对于 ~/.keras/datasets |
|---|---|
test_split | 保留作为测试集的数据的一部分。 |
seed | 用于在计算测试分割之前打乱数据的随机种子。 |
数据查看
import tensorflow as tf
from tensorflow.keras import datasets
(x_train, y_train), (x_test, y_test) = datasets.boston_housing.load_data(path='boston_housing.npz', test_split=0.2, seed=113)
x_train.shape, y_train.shape, x_test.shape, y_test.shape
x_train
y_train
cifar10
这是一个包含 50,000 张 32x32 彩色训练图像和 10,000 张测试图像的数据集,标记了 10 多个类别。
tf.keras.datasets.cifar10.load_data()
类别:
| Label | Description |
|---|---|
| 0 | airplane |
| 1 | automobile |
| 2 | bird |
| 3 | cat |
| 4 | deer |
| 5 | dog |
| 6 | frog |
| 7 | horse |
| 8 | ship |
| 9 | truck |
- x_train:uint8 NumPy 灰度图像数据数组 (50000, 32, 32, 3),包含训练数据。像素值范围从 0 到 255。
- y_train:uint8 NumPy 标签数组(0-9 范围内的整数),具有(50000, 1)训练数据的形状。
- x_test:uint8 NumPy 灰度图像数据数组 (10000, 32, 32, 3),包含测试数据。像素值范围从 0 到 255。
- y_test:uint8 NumPy 标签数组(0-9 范围内的整数),具有(10000, 1)测试数据的形状。
数据查看
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train.shape,x_test.shape,y_train.shape,y_test.shape
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']
plt.figure(figsize=(20,5))
for i in range(20):
plt.subplot(2,10,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(x_train[i])
plt.xlabel(class_names[y_train[i,0]])
plt.show()
cifar100
这是一个包含 50,000 张 32x32 颜色训练图像和 10,000 张测试图像的数据集,标记了 100 多个细粒度类,这些细粒度类又分为 20 个粗粒度类。
tf.keras.datasets.cifar100.load_data(
label_mode='fine'
)
label_mode | “fine”、“coarse”之一。如果是“fine”,则类别标签是细粒度标签,如果是“coarse”,则输出标签是粗粒度超类。 |
|---|
数据查看
(x_train, y_train), (x_test, y_test) = datasets.cifar100.load_data()
x_train.shape,x_test.shape,y_train.shape,y_test.shape
plt.figure(figsize=(20,5))
for i in range(20):
plt.subplot(2,10,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(x_train[i])
plt.xlabel(y_train[i,0])
plt.show()
fashion_mnist
这是一个包含 10 个时尚类别的 60,000 张 28x28 灰度图像的数据集,以及一个包含 10,000 张图像的测试集。该数据集可用作 MNIST 的直接替代品。
| Label | Description |
|---|---|
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |
数据查看
(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()
x_train.shape,x_test.shape,y_train.shape,y_test.shape
plt.figure(figsize=(20,5))
for i in range(20):
plt.subplot(2,10,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(x_train[i])
plt.xlabel(y_train[i])
plt.show()
imdb
这是来自 IMDB 的 25,000 条电影评论的数据集,按情绪(正面/负面)标记。评论已经过预处理,每个评论都被编码为单词索引(整数)列表。为了方便起见,单词按数据集中的总体频率进行索引,例如整数“3”编码数据中第三个最常见的单词。这允许快速过滤操作,例如:“仅考虑前 10,000 个最常见的单词,但消除前 20 个最常见的单词”。
按照惯例,“0”并不代表特定的单词,而是用于对填充令牌进行编码。
tf.keras.datasets.imdb.load_data(
path='imdb.npz',
num_words=None,
skip_top=0,
maxlen=None,
seed=113,
start_char=1,
oov_char=2,
index_from=3,
**kwargs
)
path | 缓存数据的位置。相对于 ~/.keras/dataset |
|---|---|
num_words | int 或None。单词按照它们出现的频率(在训练集中)进行排名,并且只保留最频繁的单词。任何频率较低的单词都将被oov_char 代替。如果时None,则保留所有单词。 |
skip_top | 跳过前 N 个最常出现的单词(这可能没有提供信息)。这些单词将在数据集中显示为oov_char 。当为 0 时,不跳过任何单词。 |
maxlen | int 或None。最大序列长度。任何更长的序列将被截断。None表示不截断。 |
seed | 随机种子。 |
start_char | 序列的开始将用此字符标记。0 通常是填充字符。默认为1. |
oov_char | 词汇外的字符。由于num_words或skip_top限制而被删除的单词将被替换为该字符。 |
index_from | 使用此索引或更高索引对实际单词进行索引。 |
get_word_index
获取将单词映射到 IMDB 数据集中索引的字典。
tf.keras.datasets.imdb.get_word_index(
path='imdb_word_index.json'
)
word_index = datasets.imdb.get_word_index()
type(word_index)
word_index
数据查看
(x_train, y_train), (x_test, y_test) = datasets.imdb.load_data()
x_train.shape,x_test.shape,y_train.shape,y_test.shape
len(x_train),len(x_train[0]),len(x_train[1]),len(x_train[2])
print(y_train)
mnist
这是一个包含 60,000 张 10 位数字的 28x28 灰度图像的数据集,以及一个包含 10,000 张图像的测试集。
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train.shape,x_test.shape,y_train.shape,y_test.shape
plt.figure(figsize=(20,5))
for i in range(20):
plt.subplot(2,10,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(x_train[i],cmap=plt.cm.binary)
plt.xlabel(y_train[i])
plt.show()
reuters
这是来自路透社的 11,228 条新闻专线的数据集,标记了超过 46 个主题。
每个新闻专线都被编码为单词索引(整数)列表。为了方便起见,单词按数据集中的总体频率进行索引,例如整数“3”编码数据中第三个最常见的单词。这允许快速过滤操作,例如:“仅考虑前 10,000 个最常见的单词,但消除前 20 个最常见的单词”。
按照惯例,“0”并不代表特定的单词,而是用于编码任何未知的单词。
tf.keras.datasets.reuters.load_data(
path='reuters.npz',
num_words=None,
skip_top=0,
maxlen=None,
test_split=0.2,
seed=113,
start_char=1,
oov_char=2,
index_from=3,
**kwargs
)
参数和imdb数据集差不多。
get_word_index
和imdb数据集的一样。
tf.keras.datasets.reuters.get_word_index(
path='reuters_word_index.json'
)
实际的单词索引从 3 开始,保留 3 个索引:0(填充)、1(开始)、2(oov)。
例如,“the”的单词索引为 1,但在实际训练数据中,“the”的索引将为 1 + 3 = 4。反之亦然,要使用此映射将训练数据中的单词索引转换回单词,索引需要减去3。