动手学撸一个Pytorch线性回归

521 阅读3分钟

一.线性回归模型

1.在数据呈现随自变量增长,因变量也随着相应的增长,例如在中学学到的一次函数:y = ax + b
这个里面自变量x与因变量y成线性关系。当数据无法全部呈现线条状分布,这个时候我们可以认定为这些数据含有噪声,当我们需要更好将这批数据用线性关系式表达出来时,即输入x得到y的值 2.在1中为简单的一次函数,属一元线性模型,我们需要最小二乘法拟合函数+梯度下降算法优化拟合度

1.1.最小二乘法

最小二乘法在高等数学中有介绍,这里就不再多加介绍了

1.2.梯度下降算法简介

梯度下降法是通过迭代找到目的极小值,或收敛到极小值的方法。(这里大家需要判定极小值最小值之间的区别)
此方法使用于无法对函数全局有掌握而又需要找到函数的极小值的情况,当函数在全域里为凹函数,则可以较快的收敛到全局的最大值(最大值的相反数为最小值)
从数据分析的角度可以理解为函数的梯度于函数的等势面垂直,此时的梯度方向是函数变化最快的方向,当函数可微时,通过反复迭代,每一步向当前点的梯度方向下降一定距离的方法,便可以找到极小值

1.3.反向传播

使用方向传播是求解梯度下降,在感知机学习器和BP算法中应用的广为使用。大家可以参考BP中的简介\

二.撸Pytorch线性回归流程

在使用pytorch训练模型的过程大致可以分为以下五个部分: 2.1 输入数据,使用np.array
2.2 定义参数和模型,构建一个未初始化的5 * 3的张量
2.3 定义损失函数
2.4 反向传播得到梯度
2.5 使用梯度优化更新模型参数\

demo:

import torch
import numpy as np
from torch.autograd import Variable

# 输入数据
XTrain = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
                  [9.779], [6.182], [7.59], [2.167], [7.042],
                  [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)

YTrain = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
                  [3.366], [2.596], [2.53], [1.221], [2.827],
                  [3.465], [1.65], [2.904], [1.3]],
                  dtype=np.float32)
# 转换为Tensor
XTrain = torch.from_numpy(XTrain)
YTrain = torch.from_numpy(YTrain)
# 定义W和b
w = Variable(torch.randn(1), requires_grad=True)
b = Variable(torch.zeros(1), requires_grad=True)
# 构建模型
XTrain = Variable(XTrain)
YTrain = Variable(YTrain)

# 定义线性模型
def linear_model(XTrain):
    Y = w * XTrain + b
    return Y

# 构建损失函数
def GetLoss(Y, YTrain):
    Loss = torch.mean((Y - YTrain) ** 2)
    return Loss

Y = linear_model(XTrain)
Loss = GetLoss(Y, YTrain)
# 反向传播求梯度
Loss.backward()
# 利用梯度下降优化
for i in range(50):
    Y = linear_model(XTrain)
    Loss = GetLoss(Y, YTrain)
    w.grad.zero_()
    b.grad.zero_()
    Loss.backward()
    w.data = w.data - 1e-2 * w.grad.data
    b.data = b.data - 1e-2 * b.grad.data
    print('epach: {},loss: {}'.format(i, Loss.item()))

结果:

epach: 0,loss: 2.9266693592071533
epach: 1,loss: 0.29842206835746765
epach: 2,loss: 0.24937935173511505
epach: 3,loss: 0.24807125329971313
epach: 4,loss: 0.24764862656593323
epach: 5,loss: 0.2472444325685501
epach: 6,loss: 0.2468426376581192
epach: 7,loss: 0.24644288420677185
epach: 8,loss: 0.24604518711566925
epach: 9,loss: 0.245649516582489
epach: 10,loss: 0.24525584280490875
epach: 11,loss: 0.24486425518989563
epach: 12,loss: 0.2444746196269989
epach: 13,loss: 0.24408702552318573
epach: 14,loss: 0.24370145797729492
epach: 15,loss: 0.24331779778003693
epach: 16,loss: 0.2429361343383789
epach: 17,loss: 0.2425564080476761
epach: 18,loss: 0.24217866361141205
epach: 19,loss: 0.24180279672145844
epach: 20,loss: 0.2414289116859436
epach: 21,loss: 0.2410569190979004
epach: 22,loss: 0.24068683385849
epach: 23,loss: 0.24031871557235718
epach: 24,loss: 0.2399524301290512
epach: 25,loss: 0.23958797752857208
epach: 26,loss: 0.23922544717788696
epach: 27,loss: 0.23886480927467346
epach: 28,loss: 0.23850592970848083
epach: 29,loss: 0.2381489872932434
epach: 30,loss: 0.23779381811618805
epach: 31,loss: 0.23744045197963715
epach: 32,loss: 0.23708899319171906
epach: 33,loss: 0.23673918843269348
epach: 34,loss: 0.2363913208246231
epach: 35,loss: 0.23604515194892883
epach: 36,loss: 0.2357008308172226
epach: 37,loss: 0.23535820841789246
epach: 38,loss: 0.23501737415790558
epach: 39,loss: 0.234678253531456
epach: 40,loss: 0.23434093594551086
epach: 41,loss: 0.23400531709194183
epach: 42,loss: 0.2336713820695877
epach: 43,loss: 0.23333920538425446
epach: 44,loss: 0.2330087125301361
epach: 45,loss: 0.23267993330955505
epach: 46,loss: 0.2323528230190277
epach: 47,loss: 0.23202744126319885
epach: 48,loss: 0.23170368373394012
epach: 49,loss: 0.2313815802335739

本文参考PYTORCH官网教程