线性回归案例

812 阅读3分钟

1、案例背景

通常来说,薪水会随着工龄的增长而增长,不同行业的薪水增长速度有所不同。本案例要应用一元线性回归模型探寻工龄对薪水的影响,即搭建薪水预测模型,并通过比较多个行业的薪水预测模型来分析各个行业的特点。

2、读取数据

首先以IT行业为例进行分析,这里选取的是北京地区IT行业中工龄在0~8年范围内的100个IT工程师的月薪数据,存储在一个名为“IT行业收入表.xlsx”的Excel工作簿中。通过如下代码读取数据。

IT行业收入表.xlsx

import pandas as pd
    df = pd.read_excel("IT行业收入表.xlsx")
    df.head()

image.png

此时的工龄为自变量,薪水为因变量,通过如下代码进行自变量、因变量选取。

    x = df[["工龄"]]
    y = df["薪水"]

这里的自变量X必须写成二维结构形式,原因之前提过;而因变量Y写成一维结构形式即可,不过如果写成二维结构形式df[['薪水']],之后的模型也能运行。

通过如下代码可以将此时的散点图绘制出来。

from matplotlib import pyplot as plt
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ["SimHei"]

plt.scatter(x,y)
plt.xlabel("工龄")
plt.ylabel("薪水")
plt.show()

image.png

3、模型搭建

from sklearn.linear_model import LinearRegression
regr = LinearRegression()
regr.fit(x,y)

4、模型可视化

plt.scatter(x,y)
plt.plot(x,regr.predict(x),color="red")
plt.xlabel("工龄")
plt.ylabel("薪水")
plt.show()

image.png

5、线性回归方程构建

# 系数
regr.coef_[0]

2497.1513476046866

# 截距
regr.intercept_

10143.131966873787

运行结果如下。因此,拟合得到的一元线性回归方程为y=2497x+10143。

6、模型优化

一元线性回归模型其实还有一个进阶版本——一元多次线性回归模型,比较常见的是一元二次线性回归模型,其形式可以表示为如下所示的公式。

y=ax2+bx+c

之所以还需要研究一元多次线性回归模型,是因为有时真正契合的趋势线可能不是一条直线,而是一条曲线。如下图所示,根据一元二次线性回归模型绘制的曲线更契合散点图呈现的数据变化趋势。

from sklearn.preprocessing import PolynomialFeatures
poly_reg = PolynomialFeatures(degree=2)
x_ = poly_reg.fit_transform(x)
  • 第1行代码引入用于增加一个多次项内容的模块PolynomialFeatures。
  • 第2行代码设置最高次项为二次项,为生成二次项数据(x2)做准备。
  • 第3行代码将原有的X转换为一个新的二维数组X_,该二维数组包含新生成的二次项数据(x2)和原有的一次项数据(x)。x_的内容为如下图所示的一个二维数组,其中第1列数据为常数项(其实就是x0),没有特殊含义,对分析结果不会产生影响;第2列数据为原有的一次项数据(x);第3列数据为新生成的二次项数据(x2)。

生成二次项数据后,就可以通过如下代码获得一元二次线性回归模型,和之前代码的区别只是把X换成了x_。

image.png

image.png

regr = LinearRegression()
regr.fit(x_,y)

然后通过如下代码就可以绘制上面的曲线图了,注意此时predict()函数中传入的参数是x_。

plt.scatter(x,y)
plt.plot(x,regr.predict(x_),color="red")
plt.xlabel("工龄")
plt.ylabel("薪水")
plt.show()

image.png

通过和之前类似的代码可以获取此时的一元二次线性回归方程的系数a、b和常数项c,代码如下。

image.png

第1行为系数,有3个数:第1个数0对应X_中常数项的系数,这也是为什么之前说X_的常数项不会对分析结果产生影响;第2个数对应X_中一次项(x)的系数,即系数b;第3个数对应X_中二次项(x2)的系数,即系数a。第2行的数对应常数项c。因此,拟合得到的一元二次线性回归方程为y=400.8x2-743.68x+13988。