前面两节介绍了TensorFlow安装和基本概念,下面我们在实战中进一步体会TensorFlow的用法。这里需要大家具备相关的背景知识,如神经网络,relu, softmax, 反向传播,卷积等等。这一节主要讲数据探索和基本函数准备。
数据采用斯坦福SVHN数据集(The Street View HouseNumbers Dataset)。
格式为MATLAB文件,可从scipy.io中import loadmat读取,这里Python 3.5版本。下面为具体代码实现,在注释部分进行讲解。
- #首先引入所需要的基本包
- importpandasaspd
- importnumpyasnp
- importmatplotlib.pyplotasplt
- importtensorflowastf
- #读matlab文件
- fromscipy.ioimportloadmatasload
- #读取数据
train= load('/Users/chenbin/Desktop/TensorFlow/test_ml/train_32x32.mat')
test= load('/Users/chenbin/Desktop/TensorFlow/test_ml/test_32x32.mat')- #这时后,我们可以打印数据的shape进行观察。
print(train['X'].shape)
print(train['y'].shape)
print(test['X'].shape)
print(test['y'].shape)- """
- 结果为:
- (32,32, 3, 73257)
- (73257,1)
- (32,32, 3, 26032)
- (26032,1)
- 训练集为73257个样本,测试集26032个样本,每个样本是32*32像素,3个像素通道,后面我们将其转化为1个通道,用灰度图显示,另外会将X样本量改为第一位置,符合我们平时处理的规格。
- """
- """
- 下面我们定义三个函数reformat(samples,labels) normalize(samples),inspect(dataset, labels, i):
- reformat用来改变原始数据格式,并对lable进行独热编码(one-hot encoding)。
- normalize用来对数据进行灰度化(
- 将三色通道转化为单色通道),把数据映射到 -1.0 ~ +1.0之间。
- inspect将图片显示有助于观察。
- """
defreformat(samples, labels):
# 改变原始数据的形状
# ( 0 1 2 3) ( 3 0 1 2)
# (图片高,图片宽,通道数,图片数) -> (图片数,图片高,图片宽,通道数)
new = np.transpose(samples, (3,0,1,2)).astype(np.float32)
# labels 变成 one-hot encoding,[2] -> [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
# digit 0 , represented as 10
# labels 变成 one-hot encoding,[10] -> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
labels = np.array([x[0]forx inlabels])
one_hot_labels = []
fornuminlabels:
one_hot = [0.0] *10
ifnum ==10:
one_hot[0] =1.0
else:
one_hot[num]=1.0
one_hot_labels.append(one_hot)
labels =np.array(one_hot_labels).astype(np.float32)
returnnew, labels
defnormalize(samples):
"""- 灰度化: 从三色通道 -> 单色通道 省内存 ,加快训练速度
- (R + G + B) / 3
- 将图片从 0 ~ 255 线性映射到-1.0 ~ +1.0
- """
a = np.add.reduce(samples,keepdims=True, axis=3)# shape (图片数,图片高,图片宽,通道数),将samples沿着原来格式相加
a = a/3.0
returna/128.0-1.0
definspect(dataset, labels, i):
# 将图片显示出来
ifdataset.shape[3] ==1:
shape = dataset.shape
dataset =dataset.reshape(shape[0], shape[1], shape[2])
print(labels)
plt.imshow(dataset)
plt.show()- #这时候我们就可以测试一下,打印一张图片看看
train_samples = train['X']
train_labels = train['y']
_train_samples,_train_labels = reformat(train_samples, train_labels)
inspect(_train_samples,_train_labels,123)- #结果为:
[0.0.0.0.0.0.1.0.0.0.]
下一节我们开始正式建立神经网络。
更多免费技术资料可关注:annalin1203