机器学习入门 06.多项式线性回归

466 阅读3分钟

这是我参与更文挑战的第9天,活动详情查看: 更文挑战

多项式回归

多项式回归与我们之前讨论的简单线性回归概念相同,只是它的图像表示是使用曲线而不是直线.相比与简单线性回归它具有更多的参数,但是它依然只有一个自变量,它的公式表达如下:

y=b0+b1x1+b2x12...+bn+x1ny = b_0 + b_1 * x_1 + b_2 * x_1^2 ... + b_n + x_1^n

多项式回归的应用适合对抛物线进行拟合,这些基本用于定义或描述非线性现象,如:

  • 组织的生长速度
  • 传染病的发展
  • 人口增长速度

多项式线性回归只有一个自变量,可以更好的拟合抛物线,适合拟合类似疾病传播速度等抛物线图像.

因为因变量y可以使用线性参数 b0b_0...bnb_n来表达,所以我们仍然称多项式回归是线性回归的一种.

python - 多项式回归实现

假如我们有薪资数据如下,我们希望通过职位和等级来预判对应级别的薪资.

PositionLevelSalary
Business Analyst145000
Junior Consultant250000
Senior Consultant360000
Manager480000
Country Manager5110000
Region Manager6150000
Partner7200000
Senior Partner8300000
C-level9500000
CEO101000000

我们来比较一下简单线性回归和多项式回归的效果区别.

我们首先加载这些数据

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

### 导入数据
dataset = pd.read_csv('data.csv')
X = dataset.iloc[:,1:2].values
y = dataset.iloc[:,2].values
# 创建简单线性回归模型
from sklearn.linear_model import LinearRegression

line_regression = LinearRegression()
line_regression.fit(X,y)

# 创建多项式矩阵
from sklearn.preprocessing import PolynomialFeatures

ploy_regression = PolynomialFeatures(degree = 4)
X_ploy = ploy_regression.fit_transform(X)

# 使用多项式矩阵集合线性回归模型
line_regression_2 = LinearRegression()
line_regression_2.fit(X_ploy,y)

可以看到我们通过sklearn创建了使用简单线性回归模型拟合的line_regression,以及使用多项式回归模型拟合的line_regression_2

下来我们来测试他们的效果如何

# 线性回归模型图像
plt.scatter(X, y, color='red')
plt.plot(X, line_regression.predict(X), color='blue')
plt.title('Linear Regression')
plt.xlabel('Position Level')
plt.ylabel('Salary')
plt.show()

Figure 2021-06-21 091921.png

可以看到简单线性模型对实际情况的集合并不好,他不能很好的拟合曲线模型.

# 线性回归模型图像
plt.scatter(X, y, color='red')
plt.plot(X, line_regression_2.predict(X_ploy), color='green')
plt.title('Polynomal Regression')
plt.xlabel('Position Level')
plt.ylabel('Salary')
plt.show()

Figure 2021-06-21 091927.png

可以看到多项式线性回归相比简单线性回归可以更好的拟合数据中的抛物线图像

但是可以看到点与点之间的连接是直线连接的,所以图像并不是很平顺,我们可以通过添加插值来让图像变得更加平顺一些.

X_grid = np.arange(min(X), max(X), 0.1)
X_grid = X_grid.reshape(len(X_grid), 1)
plt.scatter(X, y, color='red')
plt.plot(X_grid, line_regression_2.predict(
    ploy_regression.fit_transform(X_grid)), color='black')
plt.title('Polynomal Regression')
plt.xlabel('Position Level')
plt.ylabel('Salary')
plt.show()

Figure 2021-06-21 091934.png

通过添加插值可以看到,图像已经变得更加平顺.

下来让我们使用拟合好的模型来预测数值

line_pred = line_regression.predict(np.array(6.5).reshape(1, -1))
// output - 330379
ploy_pred = line_regression_2.predict(
    ploy_regression.fit_transform(np.array(6.5).reshape(1, -1)))
// output - 158862

可以看到多项式回归模型的预测值更符合实际情况.