神经网络的基本骨架-nn.Module的使用

178 阅读1分钟

一、PyTorch 中的 nn.Module

nn.Module 是 PyTorch 框架中的核心类,用于定义和创建神经网络的基本结构。它可以帮助我们组织和管理网络中的各种组件,比如神经网络的层。在使用 PyTorch 创建神经网络时,通常都会创建一个继承自 nn.Module 的类。

二、为什么使用 nn.Module

  • 组织网络:通过 nn.Module,我们可以简单地定义并组织神经网络的各个层。它提供了一种结构化的方式来把常见的计算组件(如卷积层、线性层等)拼接在一起。
  • 参数管理nn.Module 自动管理所有定义的层和它们的可训练参数,这让我们免去手动处理每个参数。
  • 设备管理:可以轻松地把整个网络模型放到 GPU 或 CPU 上运行

三、使用步骤

1.创建一个新类继承 nn.Module

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()  # 调用父类的初始化方法
        self.fc1 = nn.Linear(784, 128)  # 第一层,全连接层
        self.fc2 = nn.Linear(128, 64)   # 第二层,全连接层
        self.out = nn.Linear(64, 10)    # 输出层,全连接层

2.定义前向传播 forward 方法

def forward(self, x):
    x = torch.flatten(x, 1)  # 展平成一维
    x = torch.relu(self.fc1(x))  # 通过第一层并进行 ReLU 激活
    x = torch.relu(self.fc2(x))  # 通过第二层并进行 ReLU 激活
    x = self.out(x)  # 计算输出
    return x

3.使用模型

# 创建模型实例
model = SimpleNet()

# 将模型移到 GPU(如果可用)
if torch.cuda.is_available():
    model = model.to('cuda')

四、服务器实操

image.png