用自动编码器进行半监督学习的教程

615 阅读7分钟

对于大多数现实世界的问题来说,获得标记的数据既费时又费钱。

例如,假设你想训练一个模型来预测r/wallstreetbets的帖子是否包含对某只股票的积极或消极情绪。你可以自动编写一个快速脚本,从子reddit上抓取帖子,但要花几个小时来阅读这些帖子,并将每个帖子标记为正面或负面。

如果制作标签需要专家的意见,就像分析电子健康记录一样,那怎么办?那么标记的数据可能需要花费数万美元来制作!这就是半监督学习。

半监督学习是一种机器学习方法,它使用少量的标记数据和大量的未标记数据来解决监督学习问题。在这篇文章中,我将解释什么是自动编码器,解释自动编码器的一些用途,并介绍我使用时尚的MNIST数据集对半监督学习进行的一个小案例研究的结果。

什么是自动编码器?

自动编码器是一种神经网络,能够在未标记的数据中找到模式和结构。一个典型的自动编码器被训练成输出其输入的副本,但对网络进行了限制,使其不能轻易地将输入复制到输出。

例如,网络的隐藏层通常比输入和输出层的维度低,迫使模型以低维格式表示输入数据,然后用这种格式来重建输入。

这不是一个现实的例子,但上图显示了一个网络,它采用了一个8 x 8的数组(通常代表2D图像或其他东西),并使其最终将特征映射到一个16维的矢量,并返回到一个8 x 8的数组。通过像这样给网络添加一个 "瓶颈",我们希望中心的16维向量中的每个元素都代表数据的某种有用特征。

因此,我们可以在未标记的数据上训练自动编码器,将数据本身作为模型的目标输出。我们希望它能捕捉到数据中有用的特征,使我们能够在相关标记数据的小数据集上训练一个有效的模型。

自动编码器是如何建立的?

一个自动编码器是由一个编码器和一个解码器组成的。

不现实的自动编码器结构

编码器是由通常逐渐变小的层组成的,输出一个包含输入数据的编码表示的层。

然后,解码器将该编码表示作为输入,通常是编码器的镜像副本,旨在输出原始输入数据。

有时,编码器和解码器中使用的权重是 "绑在一起的"。解码器使用编码器权重的转置形式,以节省训练期间的时间和资源。这是因为解码器在训练完自动编码器后通常不会被使用,所以它的实现对问题并不那么重要。

自动编码器的用途

一般来说,自动编码器只是一个被训练来重现其输入的模型,有一些约束条件。然而,这个约束并不需要强加于模型本身,有些模型可以有比输入/输出层更大的隐藏层。

例如,你可以训练一个自动编码器对图像进行去噪。你采取一些输入图像,并以某种方式添加人工噪声,并将噪声图像作为输入。因此,该模型被训练成在看到噪声版本后产生原始图像。它仍然被迫在一定程度上拾取输入的有意义的特征。

一个类似的例子是训练一个模型来消除图像的模糊。

自动编码器也可以用于探索性数据分析,我将在下一节中展示。

案例研究。时尚MNIST

时尚MNIST是一个由28x28灰度图像组成的数据集,包括不同的服装物品。总共有10个服装类别,目标是预测每张图片所属的服装类别。

时尚MNIST(图片来自谷歌)

我的训练集包含55000张有标签的服装图片,但在这个项目中我想假装只有5000张。

自动编码器

这是我的基本自动编码器的结构。基本上,它需要一个28x28的图像,最终把它变成一个3x3x64的表示(有点,这里不解释卷积网络),并最终再次输出一个28x28的表示。

这是一些它生成的图像,在同时对55000张服装图像作为输入和目标进行训练之后。

使用自动编码器进行降维

我在编码器上附加了一个全局平均池层,这样它就能输出一个64维的向量,而不是3x3x64的矩阵。然后我使用不同的降维方法将数据的维度降低到2维,这样我就可以为我的数据绘制代表点。

这是自动编码器可以用于可视化的一种方式。我们可以看到,鞋类在右下方,裤子在左下方,衣服/包在上面。有些类别很容易分开,我们已经可以看到清晰的集群形成,而其他类别可能很难区分(例如衬衫和其他东西,在这个视觉上有很多重叠的部分)。

我现在想在5000张有标签的图像的一个小子集上训练一个模型,看看它能做得多好。我还想把它与在55000张图片的完整训练集上训练的模型进行比较。

使用编码器进行监督式学习

从上面的视觉效果来看,我们可以看到,自动编码器产生的编码可以清楚地识别输入图像中的有用特征。请记住,在自动编码器中,"编码 "实际上表示为64个3x3的特征图(3x3x64),但我们把这些图缩减为一个2维的向量来进行可视化。编码可能包含比上面的图形更有用的数据。

使用有限的标记数据,我们想训练一个相当准确的模型,能够对时尚的MNIST数据集进行预测。

首先,我必须设计这个模型。我想出了一个简单的模型架构,并添加了一些特征以防止过度拟合。

我简单地将输出的特征图平铺成一个576维的向量,然后加入一些具有不同程度Dropout的全连接网络层。

如果这没有什么意义也没关系,主要的想法是我把编码器连接到我自己的网络。我现在把这些编码作为输入到一个神经网络中,这个神经网络是为了预测10个时尚的MNIST类别而建立的。

结果

我训练了包含编码器的模型,并能够达到89%的训练准确率和86%的验证准确率。这表明该模型并没有过度拟合,而且在服装分类方面相当有效。我没有对模型进行太多的微调,所以这相当令人印象深刻

我还在整个训练数据集上训练了一个具有类似架构的卷积网络,验证准确率达到了88%。

当我在较小的(5000个观察)数据集上训练该模型时,它只能达到80%左右的验证准确率。

总的来说,这表明我们可以利用半监督学习的自动编码器来提高我们模型的性能。编码器模型是在少了11倍的数据上训练出来的,并取得了不错的准确率。

在数据科学的实际应用中,在评估一个特定的方法时,必须考虑许多与手头问题无关的问题,例如解决方案的效率和速度。希望这篇文章可以作为一个很好的介绍,介绍众多方法中的一种,可以帮助实现这一点。