用 最通俗的语言 + 生活比喻 + 代码示例,向初学者彻底讲清楚:
🎯
nn.Module主要成员函数的作用与用法
🧩 一句话总结:
nn.Module是所有神经网络的“基类”,它提供了一套“管理工具箱”,帮你自动管理参数、设备、训练状态、保存加载等 —— 你只需要专注写__init__和forward!
🏗️ 类比:nn.Module = 一个“智能机器人组装工厂”
你是一个机器人设计师:
- 你负责设计机器人的“零件”(
__init__) - 你负责设计机器人的“工作流程”(
forward) - 工厂(
nn.Module)自动帮你:- 管理所有零件(参数)
- 给零件编号(命名)
- 搬家去新车间(GPU)
- 打包发货(保存模型)
- 收货组装(加载模型)
- 切换工作模式(训练/推理)
→ 你不用操心这些杂事,专心设计机器人就行!
✅ 核心成员函数详解(初学者必懂的6个)
1️⃣ __init__(self) —— “零件采购与组装”
作用:初始化你的网络,定义所有层和参数。
class MyNet(nn.Module):
def __init__(self):
super().__init__() # 👈 必须调用父类初始化!
self.fc1 = nn.Linear(10, 20) # 定义第一层
self.fc2 = nn.Linear(20, 5) # 定义第二层
self.dropout = nn.Dropout(0.5) # 定义 dropout
📌 必须做:
- 调用
super().__init__() - 用
self.xxx = ...定义层(这些层必须是nn.Module子类或Parameter/Buffer)
✅ 比喻:你告诉工厂:“我要买一个马达(fc1)、一个传感器(fc2)、一个保险丝(dropout)”
2️⃣ forward(self, x) —— “机器人工作流程”
作用:定义数据如何流过网络 —— 这是你最常写的函数!
def forward(self, x):
x = F.relu(self.fc1(x)) # 第一层 + 激活
x = self.dropout(x) # dropout
x = self.fc2(x) # 输出层
return F.log_softmax(x, dim=1)
📌 注意:
- 不要直接调用
model.forward(x)→ 用model(x) - 用
F.xxx(如F.relu)做无状态操作 - 用
self.xxx(x)调用有状态层(如self.dropout(x))
✅ 比喻:你写操作手册:“先启动马达,再过保险丝,最后输出信号”
3️⃣ parameters() / named_parameters() —— “列出所有可训练零件”
作用:获取所有可训练参数,用于优化器。
model = MyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 查看参数名和形状
for name, param in model.named_parameters():
print(name, param.shape)
# 输出:
# fc1.weight torch.Size([20, 10])
# fc1.bias torch.Size([20])
# fc2.weight torch.Size([5, 20])
# fc2.bias torch.Size([5])
✅ 比喻:工厂给你一份“可更换零件清单”,方便你找维修工(优化器)来更新
4️⃣ to(device) —— “把整个机器人搬去新车间”
作用:把模型和所有参数/缓冲区搬到指定设备(CPU/GPU)。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device) # 👈 一行搞定!所有参数和buffer自动搬家!
✅ 比喻:你说“把整个机器人搬到新车间”,工厂自动把所有零件、螺丝、电线一起搬过去!
5️⃣ train() / eval() —— “切换工作模式”
作用:切换模型的“训练模式”或“推理模式”,影响 Dropout、BatchNorm 等层的行为。
model.train() # 启用 Dropout、BatchNorm 更新统计量
# ... 训练代码 ...
model.eval() # 关闭 Dropout、BatchNorm 用固定统计量
with torch.no_grad():
output = model(x) # 推理
✅ 比喻:
train()→ “机器人进入训练模式:允许随机关闭零件(Dropout),记录新数据(BatchNorm)”eval()→ “机器人进入工作模式:所有零件全开,用历史数据”
6️⃣ state_dict() / load_state_dict() —— “打包发货”与“收货组装”
作用:保存和加载模型的所有参数和缓冲区。
# 保存模型
torch.save(model.state_dict(), 'model.pth')
# 加载模型(先初始化相同结构)
model = MyNet()
model.load_state_dict(torch.load('model.pth'))
model.eval() # 别忘了切换模式!
✅ 比喻:
state_dict()→ “把机器人拆成零件包,贴上标签(参数名)”load_state_dict()→ “按标签把零件装回去”
🆚 附:buffers() vs parameters()
| 函数 | 包含内容 | 是否可训练 |
|---|---|---|
parameters() | 所有 nn.Parameter | ✅ 是 |
buffers() | 所有 register_buffer | ❌ 否 |
for name, buf in model.named_buffers():
print("Buffer:", name, buf.shape)
# 比如:位置编码、running_mean 等
nn.Parameter是用来“绑定”那些“要训练、要更新”的张量(比如权重、偏置);register_buffer是用来“绑定”那些“不用训练、但要跟着模型走”的张量(比如位置编码、统计量)。
register_buffer函数的第一个name参数 可以是任意合法的 Python 字符串(只要不和已有属性冲突),它的作用是 —— 给这个 buffer 起个“名字”,以后通过self.name来访问它,并在保存/加载模型时作为字典的 key。
📊 初学者速查表
| 函数 | 作用 | 何时用 | 示例 |
|---|---|---|---|
__init__ | 定义网络结构 | 一次(类初始化时) | self.fc1 = nn.Linear(...) |
forward | 定义前向计算 | 每次预测/训练时 | return F.relu(self.fc1(x)) |
parameters() | 获取可训练参数 | 初始化优化器时 | optim.SGD(model.parameters(), ...) |
to(device) | 搬迁模型到设备 | 训练/推理前 | model.to('cuda') |
train() | 切换训练模式 | 训练循环开始前 | model.train() |
eval() | 切换推理模式 | 测试/验证前 | model.eval() |
state_dict() | 获取模型状态 | 保存模型时 | torch.save(model.state_dict(), ...) |
load_state_dict() | 加载模型状态 | 加载模型时 | model.load_state_dict(torch.load(...)) |
💡 最佳实践口诀:
“init 定结构,forward 写流程;
parameters 给优化器,to(device) 搬全家;
train/eval 切模式,state_dict 保存它!”
🚫 初学者常见错误:
- ❌ 忘记
super().__init__() - ❌ 在
forward里定义新层(应该在__init__里定义) - ❌ 用
model.forward(x)而不是model(x) - ❌ 忘记
model.train()/model.eval() - ❌ 保存整个模型
torch.save(model, ...)→ 推荐只保存state_dict()
✅ 完整示例:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleClassifier(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, num_classes)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x # 返回 logits,不在这里 softmax!
# 使用
model = SimpleClassifier(10, 20, 3)
model.to('cuda') # 搬家
# 训练
optimizer = torch.optim.Adam(model.parameters())
model.train()
for epoch in range(100):
optimizer.zero_grad()
logits = model(x_train)
loss = F.cross_entropy(logits, y_train)
loss.backward()
optimizer.step()
# 推理
model.eval()
with torch.no_grad():
logits = model(x_test)
preds = logits.argmax(dim=1)
🎓 总结:
nn.Module 就像一个“贴心管家”,帮你处理了神经网络中 80% 的杂务 —— 你只需要:
- 在
__init__里“买零件” - 在
forward里“写说明书”
其他的,交给它!
掌握这6个核心函数,你就掌握了 PyTorch 模型的“任督二脉” 🧘♂️