07-nn.Module的简单使用

146 阅读1分钟

神经网络的基本骨架-nn.Module

进入pytorch官网,点击Docs,选择Pytorch,左侧Python API下面点击torch.nn

nn: neural network

torch.nn — PyTorch 2.1 documentation

TORCH.NN

These are the basic building blocks for graphs:

torch.nn

  • Containers “容器”:给神经网络定义了一些骨架和结构,需要在骨架中添加不同的内容来组成神经网络

下面则是骨架中需要填充的东西

ModuleBase class for all neural network modules.
SequentialA sequential container.
ModuleListHolds submodules in a list.
ModuleDictHolds submodules in a dictionary.
ParameterListHolds parameters in a list.
ParameterDictHolds parameters in a dictionary.
import torch.nn as nn
import torch.nn.functional as F
​
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)
​
    def forward(self, x):  # x is input data
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

简单示例

import torch
from torch import nn
​
# 搭建自己的神经网络
class XiaoMo(nn.Module):
    def __init__(self):
        super(XiaoMo, self).__init__()
​
    def forward(self, input):
            output = input + 1
            return output
​
xiaomo = XiaoMo()
x = torch.tensor(1.0)
output = xiaomo(x)
print(output)