深度学习——线性回归实现笔记(上)

1,386 阅读3分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 12 月更文挑战」的第17天,点击查看活动详情 这个是我个人学习笔记,跟着b站沐神学习,链接: 08 线性回归 + 基础优化算法【动手学深度学习v2】_哔哩哔哩_bilibili

我仅仅对代码进行一些解读,发现有解读不对的地方,欢迎大家来评论区讨论

我将其分为上下两篇来讲解

🎄生成数据集 

我们将从零开始实现整个方法,包括数据流水线、模型、损失函数和小批量随机梯度下降优化器

🎋第一步:导包 

import torch
import random
from d2l import torch as d2l

image.gif

若显示报错,没有d2l这个模块,在jupyter notebook中输入

!pip install -U d2l

image.gif

 🎋第二步:构造函数

image.png

image.png

def synthetic_data(w,b,num_examples):
    '''生成y = wx + b + 噪声。  '''
    x = torch.normal(0,1,(num_examples,len(w)))
    y = torch.matmul(x,w) + b
    y += torch.normal(0,0.01,y.shape)
    return x,y.reshape((-1,1))
true_w = torch.tensor([2,-3.4])
true_b = 4.2
features,labels = synthetic_data(true_w,true_b,1000)
print("features:",features[0],"\nlabels:",labels[0])

image.gif

1)此处b是个一维向量,当matmul的第一个参数是2维向量,第2个参数是一维向量时,返回的是矩阵和向量的乘积,结果是向量,因此,y需要reshape

2)reshape中-1表示自动计算,1表示固定,即列向量为1

3)创建一个形状为(3,4)的张量。 其中的每个元素都从均值为0、标准差为1的标准高斯分布(正态分布)中随机采样。

normal(0, 1, size=(3, 4))

🎋结果 

image.png

 🎋第三步:画图看一下是不是线性相关

d2l.set_figsize()
d2l.plt.scatter(features[:,-1].detach().numpy(),
               labels.detach().numpy(),1)

image.gif

1)detach()分离出数值,不再含有梯度

2)scatter()函数最后的一个1是绘制点直径的大小,如果改成50会看到一个个点非常粗

3)总的来说 生成第二个特征features[:, 1]labels的散点图, 可以直观观察到两者之间的线性关系

image.gifimage.png

🎄读取数据集 

🎋第一步:定义函数

imageimage.gif编辑

训练模型时要对数据集进行遍历,每次抽取一小批量样本,并使用它们来更新我们的模型。 由于这个过程是训练机器学习算法的基础,所以有必要定义一个函数, 该函数能打乱数据集中的样本并以小批量方式获取数据。

在下面的代码中,我们定义一个data_iter函数, 该函数接收批量大小、特征矩阵和标签向量作为输入,生成大小为batch_size的小批量。 每个小批量包含一组特征和标签

def data_iter(batch_size,features,labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    #这些样本是随机读取的,没有特定的顺序
    random.shuffle(indices)
    for i in range(0,num_examples,batch_size):
        batch_indices = torch.tensor(
        indices[i:min(i+batch_size,num_examples)])
        yield features[batch_indices],labels[batch_indices]

image.gif

1)只是indices这个list被打乱了,features和labels都是顺序的,用循环才能随机地放进去

2)min的作用     不让访问越界    list超出会报错,out of index

3)通过 yield,创建生成器 链接:【python】基础知识巩固(一)_heart_6662的博客-CSDN博客

我们不再需要编写读文件的迭代类,就可以轻松实现文件读取

🎋第二步:感受一下小批量运算

我们直观感受一下小批量运算:读取第一个小批量数据样本并打印。 每个批量的特征维度显示批量大小和输入特征数。 同样的,批量的标签形状与batch_size相等

batch_size = 10
for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break

image.gif

image.png