一文讲清 nn.Module 主要成员函数作用

72 阅读5分钟

最通俗的语言 + 生活比喻 + 代码示例,向初学者彻底讲清楚:

🎯 nn.Module 主要成员函数的作用与用法


🧩 一句话总结:

nn.Module 是所有神经网络的“基类”,它提供了一套“管理工具箱”,帮你自动管理参数、设备、训练状态、保存加载等 —— 你只需要专注写 __init__forward


🏗️ 类比:nn.Module = 一个“智能机器人组装工厂”

你是一个机器人设计师:

  • 你负责设计机器人的“零件”(__init__
  • 你负责设计机器人的“工作流程”(forward
  • 工厂(nn.Module)自动帮你:
    • 管理所有零件(参数)
    • 给零件编号(命名)
    • 搬家去新车间(GPU)
    • 打包发货(保存模型)
    • 收货组装(加载模型)
    • 切换工作模式(训练/推理)

→ 你不用操心这些杂事,专心设计机器人就行!


✅ 核心成员函数详解(初学者必懂的6个)


1️⃣ __init__(self) —— “零件采购与组装”

作用:初始化你的网络,定义所有层和参数。

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()  # 👈 必须调用父类初始化!
        self.fc1 = nn.Linear(10, 20)   # 定义第一层
        self.fc2 = nn.Linear(20, 5)    # 定义第二层
        self.dropout = nn.Dropout(0.5) # 定义 dropout

📌 必须做

  • 调用 super().__init__()
  • self.xxx = ... 定义层(这些层必须是 nn.Module 子类或 Parameter/Buffer

比喻:你告诉工厂:“我要买一个马达(fc1)、一个传感器(fc2)、一个保险丝(dropout)”


2️⃣ forward(self, x) —— “机器人工作流程”

作用:定义数据如何流过网络 —— 这是你最常写的函数!

def forward(self, x):
    x = F.relu(self.fc1(x))  # 第一层 + 激活
    x = self.dropout(x)      # dropout
    x = self.fc2(x)          # 输出层
    return F.log_softmax(x, dim=1)

📌 注意

  • 不要直接调用 model.forward(x) → 用 model(x)
  • F.xxx(如 F.relu)做无状态操作
  • self.xxx(x) 调用有状态层(如 self.dropout(x)

比喻:你写操作手册:“先启动马达,再过保险丝,最后输出信号”


3️⃣ parameters() / named_parameters() —— “列出所有可训练零件”

作用:获取所有可训练参数,用于优化器。

model = MyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 查看参数名和形状
for name, param in model.named_parameters():
    print(name, param.shape)
# 输出:
# fc1.weight torch.Size([20, 10])
# fc1.bias   torch.Size([20])
# fc2.weight torch.Size([5, 20])
# fc2.bias   torch.Size([5])

比喻:工厂给你一份“可更换零件清单”,方便你找维修工(优化器)来更新


4️⃣ to(device) —— “把整个机器人搬去新车间”

作用:把模型和所有参数/缓冲区搬到指定设备(CPU/GPU)。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)  # 👈 一行搞定!所有参数和buffer自动搬家!

比喻:你说“把整个机器人搬到新车间”,工厂自动把所有零件、螺丝、电线一起搬过去!


5️⃣ train() / eval() —— “切换工作模式”

作用:切换模型的“训练模式”或“推理模式”,影响 Dropout、BatchNorm 等层的行为。

model.train()   # 启用 Dropout、BatchNorm 更新统计量
# ... 训练代码 ...

model.eval()    # 关闭 Dropout、BatchNorm 用固定统计量
with torch.no_grad():
    output = model(x)  # 推理

比喻

  • train() → “机器人进入训练模式:允许随机关闭零件(Dropout),记录新数据(BatchNorm)”
  • eval() → “机器人进入工作模式:所有零件全开,用历史数据”

6️⃣ state_dict() / load_state_dict() —— “打包发货”与“收货组装”

作用:保存和加载模型的所有参数和缓冲区。

# 保存模型
torch.save(model.state_dict(), 'model.pth')

# 加载模型(先初始化相同结构)
model = MyNet()
model.load_state_dict(torch.load('model.pth'))
model.eval()  # 别忘了切换模式!

比喻

  • state_dict() → “把机器人拆成零件包,贴上标签(参数名)”
  • load_state_dict() → “按标签把零件装回去”

🆚 附:buffers() vs parameters()

函数包含内容是否可训练
parameters()所有 nn.Parameter✅ 是
buffers()所有 register_buffer❌ 否
for name, buf in model.named_buffers():
    print("Buffer:", name, buf.shape)
# 比如:位置编码、running_mean 等

nn.Parameter 是用来“绑定”那些“要训练、要更新”的张量(比如权重、偏置); register_buffer 是用来“绑定”那些“不用训练、但要跟着模型走”的张量(比如位置编码、统计量)。

register_buffer 函数的第一个name参数 可以是任意合法的 Python 字符串(只要不和已有属性冲突),它的作用是 —— 给这个 buffer 起个“名字”,以后通过 self.name 来访问它,并在保存/加载模型时作为字典的 key。


📊 初学者速查表

函数作用何时用示例
__init__定义网络结构一次(类初始化时)self.fc1 = nn.Linear(...)
forward定义前向计算每次预测/训练时return F.relu(self.fc1(x))
parameters()获取可训练参数初始化优化器时optim.SGD(model.parameters(), ...)
to(device)搬迁模型到设备训练/推理前model.to('cuda')
train()切换训练模式训练循环开始前model.train()
eval()切换推理模式测试/验证前model.eval()
state_dict()获取模型状态保存模型时torch.save(model.state_dict(), ...)
load_state_dict()加载模型状态加载模型时model.load_state_dict(torch.load(...))

💡 最佳实践口诀:

“init 定结构,forward 写流程;
parameters 给优化器,to(device) 搬全家;
train/eval 切模式,state_dict 保存它!”


🚫 初学者常见错误:

  1. ❌ 忘记 super().__init__()
  2. ❌ 在 forward 里定义新层(应该在 __init__ 里定义)
  3. ❌ 用 model.forward(x) 而不是 model(x)
  4. ❌ 忘记 model.train() / model.eval()
  5. ❌ 保存整个模型 torch.save(model, ...) → 推荐只保存 state_dict()

✅ 完整示例:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x  # 返回 logits,不在这里 softmax!

# 使用
model = SimpleClassifier(10, 20, 3)
model.to('cuda')  # 搬家

# 训练
optimizer = torch.optim.Adam(model.parameters())
model.train()
for epoch in range(100):
    optimizer.zero_grad()
    logits = model(x_train)
    loss = F.cross_entropy(logits, y_train)
    loss.backward()
    optimizer.step()

# 推理
model.eval()
with torch.no_grad():
    logits = model(x_test)
    preds = logits.argmax(dim=1)

🎓 总结:

nn.Module 就像一个“贴心管家”,帮你处理了神经网络中 80% 的杂务 —— 你只需要:

  1. __init__ 里“买零件”
  2. forward 里“写说明书”

其他的,交给它!

掌握这6个核心函数,你就掌握了 PyTorch 模型的“任督二脉” 🧘‍♂️