(本文代码部分使用numpy编写)
在线性回归预测中,最简单的就是闭式解closed form solution)也叫解析解(analytical solution).
一些严格的公式,给出任意的自变量就可以求出其因变量,也就是问题的解,例如咱们中学中最简单的一元一次方程:
y = wx + b
但是在实际预测中往往会出现一些偏差,我们就会添加一个误差参数e,公式就变成了:
y = wx + b + e
然而,在实际操作中咱们只会知道一些以往的输入x,以及对应的输出y.然后让你对之后过来的输入预测其输出,就如同下面的图集:
那么问题就来了,咱们怎么知道w,b的值?
为了回答这个问题,咱们可以构造一个损失函数loss
loss = ∑(wx + b - y)
当loss越小,说明函数预测越准确.当loss最小时,其w,b就是咱们所需要的.
我们应该知道,一个函数的的斜率总是指向函数极大点的,就如图:
那咱们就可以构造w.b的公式:
w = ∑(w - lr * dloss/dw)
b = ∑(b - lr * dloss/db)
其中lr为步长,提高w,b精度.
寻找loss最小值的过程就是个迭代过程,当w,b斜率为正时,w,b会减小,反正亦然.
值得注意的是,并不会出现一个最准确的值,loss只会在及小值附近徘徊
代码如下:
import numpy as np
#更新wbdef step_gradient(data, w, b, learn_rate):
data_length = len(data)
w_gradient, b_gradient = 0,0
for i in range(data_length):
x = data[i, 0]
y = data[i, 1]
#求w b偏导
w_gradient += (2 / data_length) * (w*x + b - y) * x
b_gradient += (2 / data_length) * (w*x + b - y)
new_w = w - (w_gradient * learn_rate)
new_b = b - (b_gradient * learn_rate)
return new_w,new_b
#迭代次数
def gradient_descent_runner(data, w, b, learn_rate, num_iteration = 0):
data_length = len(data)
total_loss = 0
for i in range(num_iteration):
w, b = step_gradient(data, w, b,learn_rate)
return w, b
#计算损失
def compute_error_for_line_given_points(data, w, b, learn_rate):
data_length = len(data)
total_loss = 0
for i in range(data_length):
w,b = step_gradient(data, w, b, learn_rate)
x = data[i, 0]
y = data[i, 1]
total_loss += ((x*w + b) - y) ** 2
return total_loss /data_lengthdef run():
data = np.genfromtxt("H:\机器学习\linear_regression_data\data.csv",delimiter=",")
learning_rate = 0.0001
inital_b = 0
init_w = 0
print("the number_itertation is : 0, the loss is {0},the w and b is {1},{2}".format(compute_error_for_line_given_points(data, init_w, inital_b, learning_rate), 0, 0))
w,b = gradient_descent_runner(data, init_w, inital_b, learning_rate, 1000)
print("the number_itertation is : 1000, the loss is {0},the w and b is {1},{2}".format(compute_error_for_line_given_points(data, w, b, learning_rate), w, b))
if __name__ == '__main__':
run()
总结:
线性回归问题主要是对参数w,b的求解,而w,b的求解需要反斜率方向前进一达到loss的最小值范围.