统计学习方法--线性感知机

240 阅读2分钟

感知机定义

感知机是一个经典的二分类问题。假设我们有一系列的数据集

T={(x0,y0),(x1,y1),,(xm,ym)}T = \left\{\mathbf{(x_0}, y_0), (\mathbf{x_1}, y_1), \cdots, (\mathbf{x_m}, y_m)\right\}

其中, xRn\mathbf{x} \in \mathbb{R}^n, y{1,1}y \in \{-1, 1\}

假设存在一个参数向量 wRn\mathbf{w}\in \mathbb{R}^n 和一个参数 bRb \in \mathbb{R},可以使得y=wTx+by = \mathbf{w}^\mathrm{T}\mathbf{x}+b 构造出的超平面能恰好分开数据集中点。那么,我们的目标,就是求解出参数向量 w\mathbf{w} 和偏置项 bb

注意,我们提到的所有向量都是列向量

我们需要定义一个损失函数,最直观的方式是,假设MM表示误分类的点集,则误分类点数个数最小是我们的求解目标;不过这种方式得出的函数不可导,不容易计算。因此,我们选择一个等价的方式,计算所有误分类点到平面的距离总和,距离总和最小则说明超平面拟合越好。

先给出,空间 Rn\mathbb{R}^n 中的点到超平面的距离公式

d=wTx+bw2d = \frac{|\mathbf{w}^\mathrm{T}\mathbf{x}+b|}{\parallel \mathbf{w} \parallel_2}

w2\parallel \mathbf{w} \parallel_2 表示向量 w\mathbf{w}L2L_2范数。

因此,如果一个点xi\mathbf{x}_i被误分类了,则yi(wTxi+b)>0-y_i(\mathbf{w}^\mathrm{T}\mathbf{x}_i+b) \gt 0

结合距离公式,我们可以得出误分类点到超平面的距离是:

yi(wTxi+b)w2-\frac{y_i(\mathbf{w}^\mathrm{T}\mathbf{x}_i+b)}{\parallel \mathbf{w} \parallel_2}

对于所有的误分类点来说,距离总和应该是:

1w2xiMyi(wTxi+b)-\frac{1}{\parallel \mathbf{w} \parallel_2} \sum_{\mathbf{x}_i \in M}y_i(\mathbf{w}^\mathrm{T}\mathbf{x}_i+b)

如果我们不考虑系数,那么等价的最小距离总和应该是:

L(w,b)=xiMyi(wTxi+b)L(\mathbf{w}, b) = \sum_{\mathbf{x}_i \in M}y_i(\mathbf{w}^\mathrm{T}\mathbf{x}_i+b)

我们需要一个算法,来根据上述条件,求出w\mathbf{w}bb

训练方法

梯度下降方法

根据上面的距离公式,我们应该最小化L(w,b)L(\mathbf{w}, b),分别对w\mathbf{w}bb求解偏导:

wL(w,b)=L(w,b)w=xiMyixi\nabla_{\mathbf{w}}{L(\mathbf{w}, b)} = \frac{\partial{L(\mathbf{w}, b)}}{\partial{\mathbf{w}}} = \sum_{\mathbf{x}_i \in M}y_i\mathbf{x}_i
bL(w,b)=L(w,b)b=xiMyi\nabla_{b}{L(\mathbf{w}, b)} = \frac{\partial{L(\mathbf{w}, b)}}{\partial{b}}=\sum_{\mathbf{x}_i \in M}y_i

上面这种方式为批梯度下降,即每次要对全部的数据执行一个计算,然后更新梯度:

w=wηbL(w,b)\mathbf{w} = \mathbf{w} - \eta \nabla_{b}{L(\mathbf{w}, b)}
b=bηbL(w,b)b = b - \eta \nabla_{b}{L(\mathbf{w}, b)}

其中0<η<10 \lt \eta \lt 1表示学习速率。上述用到了向量求解偏导的方式,具体参考这篇文档

为了加快计算速度,我们可以使用随机梯度下降(SGD)的方式,即每次随机选择一个数据点进行跟新,则最终的更新方式为:

w=wηyixi\mathbf{w} = \mathbf{w} - \eta y_i \mathbf{x}_i
b=bηyib = b - \eta y_i

算法会一直迭代,直到M=M = \emptyset,或者L(w,b)ϵL(\mathbf{w}, b) \le \epsilon,这里ϵ\epsilon表示我们可以忍受的最小值。

对偶方法

未完待续 ......