本文已参与「新人创作礼」活动,一起开启掘金创作之路。
包括数据创建,模型定义,超参定义,训练,测试等完整过程。
# 练习:一个超级简单的线性回归模型
import torch
import numpy as np
import torch.nn as nn
# 构造一组输入数据x及其对应标签y
x_values = [i for i in range(1, 10)]
x_train = np.array(x_values, dtype=np.float32) # 构造numpy
x_train = x_train.reshape(-1, 1) # 为防止出错,一般情况下将数据转换为矩阵格式,此处为列向量矩阵
y_values = [2 * i + 1 for i in x_values] # 假设y=2*x+1
y_train = np.array(y_values, dtype=np.float32)
y_train = y_train.reshape(-1, 1)
# 定义模型
class LinearRegressionModel(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearRegressionModel, self).__init__()
self.linear = nn.Linear(input_dim, output_dim) # 全连接层,线性回归就是不需要激活函数的全连接层
def forward(self, x): # 前向传播
out = self.linear(x) # 输入x,输出out
return out
if __name__ == "__main__":
input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim, output_dim) # 模型
epochs = 100 # 训练次数
learning_rate = 0.01 # 学习率
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # 优化器
loss_fun = nn.MSELoss() # 损失函数,分类任务常用交叉熵损失函数,回归任务常用MSE
# 训练
for epoch in range(epochs):
epoch += 1
inputs = torch.from_numpy(x_train) # 将ndarrey格式转换为Terson
labels = torch.from_numpy(y_train)
optimizer.zero_grad() # 每次迭代前或结束后梯度清零,否则会累加
outputs = model(inputs) # 前向传播
loss = loss_fun(outputs, labels) # 计算损失
loss.backward() # 返向传播,自动计算梯度
optimizer.step() # 更新权重参数
if epoch % 10 == 0: # 每10个epoch输出一次损失值
print("loss of No {0} epoch is: {1}".format(epoch, loss.item()))
# 测试
predicted = model(torch.from_numpy(x_train).requires_grad_()) # 前向传播,计算输出
predicted = predicted.data.numpy() # 转换成ndarrey格式,便于画图制表等
for j in x_values:
print(x_values[j], predicted[j], (predicted[j]-1)/x_values[j])
程序运行结果:
loss of No 10 epoch is: 0.09011725336313248 loss of No 20 epoch is: 0.08299523591995239 loss of No 30 epoch is: 0.07643688470125198 loss of No 40 epoch is: 0.0703967958688736 loss of No 50 epoch is: 0.06483396887779236 loss of No 60 epoch is: 0.05971059575676918 loss of No 70 epoch is: 0.054992228746414185 loss of No 80 epoch is: 0.050646599382162094 loss of No 90 epoch is: 0.04664454981684685 loss of No 100 epoch is: 0.0429585836827755 2 [4.6931477] [1.8465738] 3 [6.764638] [1.921546] 4 [8.836127] [1.9590318] 5 [10.907617] [1.9815233] 6 [12.979107] [1.9965178] 7 [15.050596] [2.0072281] 8 [17.122086] [2.0152607] 9 [19.193575] [2.0215082]