线性回归

106 阅读3分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 6 月更文挑战」的第7天,点击查看活动详情

之前我们都是根据经验,直接使用了线性回归来预测过预测塞浦路斯的生活满意度。对于线性回归的内部流程其实还并没有真正了解。当然在项目实施中,我们可能会为了尽快做出效果,先会选择一个不错的但其实并不十分了解内部实施细节的模型先完成需求。

但是如果能了解内部机制,对于我们的调优有很好的帮助,能够更快速的定位错误,去修改模型直至效果最好。接下来我们再仔细了解一下线性回归。

线性回归

在预测生活满意度是,我们有得到过如下公式:

生活满意度=θ0+θ1×人均GDP生活满意度=\theta_{0} + \theta_{1}\times 人均GDP

现在我们继续深入了解其作用,我们可以看出生活满意度是关于人均GDP一个线性公式,或者说线性模型就是一个对输入特征加权求和

  • 输入特征 这边就是人均GDP
  • 特征权重 θ1\theta_{1}
  • 偏置项 一个常数θ0\theta_{0},让线性回归能够拟合更好

我们可以按照上面的内容推出更加普遍的一个公式:

y^=θ0+θ1×x1++θn×xn\hat{y} = \theta_{0} + \theta_{1}\times x_{1}+ \cdot \cdot \cdot + \theta_{n}\times x_{n}

其中:

  • y^\hat{y} 是我们的预测值
  • xix_{i} 是第i个特征值,就比如我们预测房价有很多的可以参考的点如房间数目、收入和地理位置等

我们平时会用向量的表示比较多,因此我们也可以写成:

y^=hθ(x)=θx\hat{y} = h_{\theta}(x) = \theta \cdot x
  • θ\theta 需要注意这边的θ\theta是一个参数的向量,如下所示:
    • (θ0θ1θm)\begin{pmatrix} \theta_{0} \\ \theta_{1}\\ \vdots \\ \theta_{m} \\ \end{pmatrix}
  • xx 就是我们的特征组成的向量,和上面一样,但是要注意x1x_1始终是1

OK,是不是对于线性模型有了更进一步的了解。我们之间都是先根据经验,然后根据需求推出可能是适合线性模型后,我们是不是就要进行训练了。训练之后我们还需要进行测量模型的好坏。这时候我们就需要使用到均方根误差(RMSE)和均方误差(MSE)。我们要做的不断优化模型,最小化他们。也就是我们要找出使得hθh_{\theta}最小的θ\theta值。

标准方程

这里我们就需要用到一个标准方程,他长得如下样子:

θ^=(XTX)1XTy\hat{\theta} = (X^{T}X)^{-1}X^{T}y
  • θ^\hat{\theta}就是我们要求的使得成本函数最小的θ\theta
  • y 就是我们的标签向量

我们可以简单的实现一下:

# 这里是100个特征值组成的向量,注意和之前我们的看的mnist的数据集表现得形式不同,这里是以列向量表示
X = 2 * np.random.rand(100,1)
# 我们的目标值向量、标签
y = 4 + 3 * X + np.random.randn(100,1)

# 补充上我们的X0,他也叫bias term,所以命名为X_b,因为只能是1,所以全部特征值填充为1
X_b = np.c_[np.ones((100,1)),X]
# numpy提供有线性代数模块np.linalg,其中的inv()函数可以用于我们求逆
# numpy的有一个属性T,可以直接返回矩阵转置
# dot()函数就是点乘,点乘的计算是X·y = x0·y0+x1·y1 + ··· + xn·yn 
theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)

# 输出theta_best
[[4.14378059]
 [9.84417013]]
 
# 以上面求出的为基础,我们做预测

# 需要预测的两个参数
X_new = np.array([[0],[2]])
# 忘了补充上我们的bais term
X_new_b = np.c_[np.ones((2,1)),X_new] 
y_pred = X_new_b.dot(theta_best)
# 输出
[[4.14378059]
 [9.84417013]]

读者可以自行去画一条我们的预测值的线,将他放入我们的训练的X和y组成的散点图中,就能看到一条线在杂乱的散点中穿过。

这下我们是不是对之间用的LinearRegression的执行流程有了一定的了解,当然求theta的方法还是有点不一样:θ^=X+y\hat{\theta} = X^{+}y,其中的X+X^{+}XX的伪逆,我们可以通过np.linalg.pinv(X)计算得到。

当一些(XTX)(X^{T}X)存在不可逆时,伪逆还是有定义的。