持续创作,加速成长!这是我参与「掘金日新计划 · 6 月更文挑战」的第20天,点击查看活动详情
像之前的遇见的mnist中的手写图像中我们知道的维度就挺高的,而事实上机器学习可能还会遇见上万甚至百万及以上的特征,特征越高训练就会越慢,这被称为维度的诅咒(curse of dimensionality)。
维度的诅咒
我们平时能看见或者想象到的一般都是在1-3维之间,如果要想一个简单的思维超立方体,可能也会比较困难。
在低纬度之间找两个点,如在平面之中平均距离是0.52,在立方体中就变成平均距离0.66,维度越高两个点的平均距离就会越来越远。因此如果在高维度中,数据集很大可能是稀疏的,绝大部分的训练实例可能是相距很远的,因此对于预测的结果很可能更加的不可靠,容易过拟合。
因此面对高维度的数据,我们需要使用一些降维算法。在此之间,我们先来看看常用的减少维度的两种方式:投影和流形学习。
投影
投影还算比较好理解,就是将原来的高维度的数据,选择低维度的一个空间坐标,将自己垂直投射到低纬度的子空间中。 可以看一个简单的瑞士卷小数据的例子(使用来自Scikit-Learn的datasets模块下的make_swiss_roll()):
# 获得一个小型瑞士卷数据集,然后将其画出
X, t = make_swiss_roll(n_samples=1000, noise=0.2, random_state=42)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=t, cmap=plt.cm.hot)
# ...省去非必要画图步骤代码
图1 瑞士卷数据集
使用make_swiss_roll()方法后,我们将其可视化可以大致看出一个瑞士卷的摸样。我们可以比如只取X的第一列和第二列的数据,就可以将上面的三位图像投影到2维上,如下图所示:
投影到平面的代码-plt.scatter(X[:, 0], X[:, 1], c=t, cmap=plt.cm.hot)
展开的部分代码-plt.scatter(t, X[:, 1], c=t, cmap=plt.cm.hot)
图2 瑞士卷投影到平面(左图)和直接将瑞士卷在2维空间展开(右图)
这边就是简单的一个投影案例,相信大家应该也对投影有个简单的认知了。
流形学习
刚才提到的瑞士卷其实也是一个2D流行。2D流行就是表示在高维度的空间中弯曲和扭曲的2D形状。
而很多的降维算法就是基于此,通过对训练实例所在的流行进行建模,称为流行学习。他依赖于流行假设:该假设认为大多现实世界的高维数据集都接近于低维流形。
还有另一个隐式假设:如果能够用流行的低维空间来表示,执行的任务就能够更加简单。比如上面的被展开的瑞士卷很明显将会更容易进行判断,而在3维中决策边界都会变得很复杂。