pytorch入门练习:超简单的线性回归模型

487 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

包括数据创建,模型定义,超参定义,训练,测试等完整过程。

# 练习:一个超级简单的线性回归模型
import torch
import numpy as np
import torch.nn as nn


# Build the training set: inputs x and targets y following y = 2*x + 1.
x_values = list(range(1, 10))  # idiom: list(range(...)), not a no-op comprehension
x_train = np.array(x_values, dtype=np.float32)  # float32 matches torch's default dtype
# Reshape to a column matrix (N, 1): nn.Linear expects 2-D (batch, features) input.
x_train = x_train.reshape(-1, 1)

y_values = [2 * i + 1 for i in x_values]  # ground-truth labels: y = 2*x + 1
y_train = np.array(y_values, dtype=np.float32)
y_train = y_train.reshape(-1, 1)


# 定义模型
class LinearRegressionModel(nn.Module):
    """Linear regression expressed as a single fully-connected layer.

    Linear regression is exactly a dense layer with no activation:
    y = W @ x + b.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return the linear prediction for input ``x``."""
        return self.linear(x)


def main():
    """Train a 1-D linear regression model on y = 2*x + 1, then print predictions.

    Reads the module-level ``x_values``/``x_train``/``y_train`` data and the
    ``LinearRegressionModel`` class defined above.
    """
    input_dim = 1
    output_dim = 1
    model = LinearRegressionModel(input_dim, output_dim)
    epochs = 100
    learning_rate = 0.01
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    # MSE for regression (cross-entropy is the usual choice for classification).
    loss_fun = nn.MSELoss()

    # Convert ndarray -> Tensor once; the data never changes between epochs,
    # so there is no reason to rebuild the tensors inside the loop.
    inputs = torch.from_numpy(x_train)
    labels = torch.from_numpy(y_train)

    # Training loop.
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()  # clear gradients each step; otherwise they accumulate
        outputs = model(inputs)  # forward pass
        loss = loss_fun(outputs, labels)
        loss.backward()  # backward pass: autograd computes gradients
        optimizer.step()  # update weights

        if epoch % 10 == 0:  # report the loss every 10 epochs
            print("loss of No {0} epoch is:  {1}".format(epoch, loss.item()))

    # Evaluation: no gradients are needed for inference.
    with torch.no_grad():
        predicted = model(inputs).numpy()  # back to ndarray for printing/plotting

    # BUG FIX: the original looped `for j in x_values` and then indexed with the
    # *values* (1..9), which skipped the first sample and raised IndexError at
    # j=9. Pair each x with its prediction instead.
    for xv, pv in zip(x_values, predicted):
        print(xv, pv, (pv - 1) / xv)


if __name__ == "__main__":
    main()

程序运行结果:

loss of No 10 epoch is:  0.09011725336313248
loss of No 20 epoch is:  0.08299523591995239
loss of No 30 epoch is:  0.07643688470125198
loss of No 40 epoch is:  0.0703967958688736
loss of No 50 epoch is:  0.06483396887779236
loss of No 60 epoch is:  0.05971059575676918
loss of No 70 epoch is:  0.054992228746414185
loss of No 80 epoch is:  0.050646599382162094
loss of No 90 epoch is:  0.04664454981684685
loss of No 100 epoch is:  0.0429585836827755
2 [4.6931477] [1.8465738]
3 [6.764638] [1.921546]
4 [8.836127] [1.9590318]
5 [10.907617] [1.9815233]
6 [12.979107] [1.9965178]
7 [15.050596] [2.0072281]
8 [17.122086] [2.0152607]
9 [19.193575] [2.0215082]