一文讲清 torch、torch.nn、torch.nn.functional 及 nn.Module

341 阅读7分钟

生活化比喻 + 代码示例,彻底讲清楚:

torchtorch.nntorch.nn.functional 三者的关系和作用


🧩 一句话总结:

模块作用类比
torch提供张量计算、自动求导、设备管理等基础能力厨房 + 基础食材 + 灶台
torch.nn提供“可训练的神经网络层/模块”,自带参数预制菜模具(带配方、带调料包)
torch.nn.functional提供“无状态的函数式操作”,不带参数厨具/调料瓶(酱油、刀、锅)—— 用完即走,不保存状态

🍳 1. torch —— 基础张量库(厨房+食材)

这是 PyTorch 的核心计算引擎,提供:

  • 张量(Tensor)创建与运算(加减乘除、矩阵乘、转置…)
  • 自动微分(.backward()
  • GPU支持(.to('cuda')
  • 随机数、数学函数等

📌 类比:

你走进厨房,里面有菜刀、砧板、灶台、盐、油、面粉、鸡蛋…
—— 这就是 torch,提供基础“烹饪能力”。

✅ 示例:

import torch

# 创建张量(食材)
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])

# 基础运算(切菜、炒菜)
z = x + y          # [5., 7., 9.]
w = torch.sin(x)   # [sin(1), sin(2), sin(3)]

# 自动求导(记录菜谱步骤)
x.requires_grad_(True)
loss = (x * 2).sum()
loss.backward()    # 自动计算梯度 → x.grad = [2, 2, 2]

print(x.grad)      # tensor([2., 2., 2.])

🧱 2. torch.nn —— 神经网络模块(预制菜模具 ✅ 带参数)

这是构建神经网络的**“积木块”**,每个模块:

  • 自带可学习参数(如权重、偏置)
  • 可保存、加载、嵌套
  • 适合构建复杂模型(如 nn.Linear, nn.Conv2d, nn.Transformer

📌 类比:

你买了一个“红烧肉预制菜包”,里面有:

  • 调料包(参数 weight, bias)
  • 使用说明(forward 方法)
  • 用完还能留着下次用(可复用、可保存)

✅ 示例:

import torch.nn as nn

# 定义一个线性层(带参数!)
linear_layer = nn.Linear(in_features=3, out_features=2)
# 内部自动创建了 weight (2x3) 和 bias (2)

x = torch.randn(5, 3)  # batch=5, 输入3维
output = linear_layer(x)  # 输出 shape: [5, 2]

print("权重:", linear_layer.weight.shape)  # torch.Size([2, 3])
print("偏置:", linear_layer.bias.shape)    # torch.Size([2])

# 这个层可以被优化器更新
optimizer = torch.optim.SGD(linear_layer.parameters(), lr=0.01)

nn.Linear, nn.Conv2d, nn.Embedding, nn.LSTM, nn.TransformerEncoder 都属于 torch.nn


🔧 3. torch.nn.functional —— 函数式操作(厨具/调料瓶 ⚡ 无状态)

这是无参数的函数集合,提供和 nn 模块相同的功能,但:

  • 不保存参数
  • 每次调用都是“临时操作”
  • 更灵活,适合在 forward 中组合使用

📌 类比:

你手边的“酱油瓶”、“炒锅”、“漏勺” ——

  • 每次用完放回原位,不保存“上次用了多少酱油”
  • 想用就拿,灵活组合

✅ 示例:

import torch.nn.functional as F

x = torch.randn(5, 3)

# 用 F.linear 实现和 nn.Linear 一样的计算,但要手动传权重
weight = torch.randn(2, 3)
bias = torch.randn(2)
output = F.linear(x, weight, bias)  # 无内部状态,纯函数

# 常用函数:
y = F.relu(x)           # 激活函数
z = F.softmax(x, dim=1) # 归一化
loss = F.cross_entropy(logits, labels)  # 损失函数

F.relu, F.softmax, F.cross_entropy, F.dropout, F.embedding 都是函数式操作


🆚 对比:什么时候用 nn.XXX vs F.xxx

场景推荐使用原因
构建模型层(带参数)nn.Linear, nn.Conv2d自动管理参数,方便优化和保存
激活函数、损失函数、dropout等F.relu, F.cross_entropy, F.dropout无参数,灵活,避免重复创建模块
forward 中临时计算F.xxx轻量、高效、不保存状态

✅ 最佳实践示例:

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)   # ✅ 用 nn,因为要保存权重
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        x = F.relu(self.fc1(x))        # ✅ 用 F.relu,因为无参数
        x = F.dropout(x, p=0.5, training=self.training)  # ✅ 用 F.dropout
        x = self.fc2(x)
        return F.log_softmax(x, dim=1) # ✅ 用 F.log_softmax

🧠 记忆口诀:

  • torch = 基础食材 + 厨房设备
  • nn = 预制菜包(带调料,可复用)
  • F = 调料瓶/厨具(随手用,不保存)

👉 带参数 → 用 nn;无参数 → 用 F


📊 三者关系图:

                     torch
                      │
                      ├── 张量计算、自动求导、设备管理
                      │
           ┌──────────┴──────────┐
           ▼                     ▼
      torch.nn           torch.nn.functional
  (带参数的模块/层)        (无参数的函数)
   Linear, Conv2d          relu, softmax, dropout
   Embedding, LSTM         cross_entropy, embedding

✅ 总结卡片:

模块用途是否带参数典型成员适用场景
torch基础张量运算、自动微分tensor, sin, matmul, backward()所有底层计算
torch.nn可训练网络层Linear, Conv2d, Transformer模型结构定义
torch.nn.functional无状态函数操作relu, softmax, cross_entropyforward 中灵活调用

现在你彻底搞懂了这三个核心模块的区别和用法!

🧠 记住:

  • 想“保存参数、构建模型” → 用 torch.nn
  • 想“临时计算、灵活组合” → 用 torch.nn.functional
  • 想“做数学、求导、转设备” → 用 torch

太好了!你问到了 PyTorch 的灵魂组件 —— nn.Module

我们继续用 🍳 生活化比喻 + 代码示例 + 结构图,彻底讲清楚:

nn.Module 是什么?它和 torchnnF 有什么关系?为什么所有模型都要继承它?


🧩 一句话总结:

nn.Module 是所有神经网络组件的“基类” —— 它像一个“智能收纳盒”,帮你管理参数、子模块、设备、训练/推理状态。


🏗️ 类比:nn.Module = 一个“智能乐高底座 + 收纳管理系统”

想象你要搭一个机器人(神经网络):

  • 你有各种零件:马达(Linear层)、传感器(Conv层)、电池(参数)
  • 你需要一个底座把它们组装起来
  • 你还需要一个管理系统:自动记录哪些是可训练零件、一键搬去GPU、一键保存所有零件…

👉 nn.Module 就是这个底座 + 管理系统!


✅ 核心功能(为什么必须用它?)

功能说明举例
🔧 自动管理参数所有 .weight, .bias 自动注册,可被优化器更新model.parameters() 返回所有参数
🧩 嵌套子模块可以包含其他 Module(如 Linear, Conv, 甚至自定义模块)Encoder 包含多个 Attention 层
🖥️ 设备迁移一键搬去 GPU/CPUmodel.to('cuda')
🎚️ 训练/推理模式切换自动管理 Dropout、BatchNorm 行为model.train() / model.eval()
💾 保存/加载模型一键保存所有参数和结构torch.save(model.state_dict(), ...)
🔄 前向传播定义你只需写 forward(),其余自动处理output = model(input)

🧱 代码示例:手写一个简单模型

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):  # 👈 继承 nn.Module
    def __init__(self):
        super().__init__()  # 初始化父类
        # 定义“带参数”的子模块 → 自动被 nn.Module 管理!
        self.fc1 = nn.Linear(10, 20)   # 参数自动注册
        self.fc2 = nn.Linear(20, 5)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # 使用“无参数”函数 → 用 F
        x = F.relu(self.fc1(x))
        x = self.dropout(x)            # Dropout 是模块,有训练/推理状态
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# 实例化模型
model = SimpleNet()

# 🎯 nn.Module 的魔法开始!
print("✅ 所有参数:")
for name, param in model.named_parameters():
    print(f"  {name}: {param.shape}")

# 🖥️ 一键搬去GPU
model.to('cuda')

# 🎚️ 切换模式
model.train()   # Dropout 生效
model.eval()    # Dropout 关闭

# 💾 保存模型
torch.save(model.state_dict(), "mymodel.pth")

🆚 nn.Module vs nn vs F 关系图

                     torch
                      │
                      ├── 基础张量、自动求导
                      │
           ┌──────────┴──────────┐
           ▼                     ▼
      torch.nn           torch.nn.functional
         │                       │
         │ (包含可训练层)         │ (纯函数,无状态)
         ▼                       ▼
    Linear, Conv2d           relu, softmax
    Embedding, LSTM          dropout, cross_entropy
         │
         │ 所有这些层都继承自 👇
         ▼
    nn.Module ← 你自定义的模型也要继承它!

nn.Linear, nn.Conv2d, nn.Transformer 都是 nn.Module 的子类
✅ 你写的 class MyModel(nn.Module) 也是!


🧠 为什么必须继承 nn.Module

如果你不继承:

class BadNet:  # ❌ 没有继承 nn.Module
    def __init__(self):
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

会发生什么?

model = BadNet()
print(list(model.parameters()))  # ❌ 返回空列表!优化器找不到参数!
model.to('cuda')                 # ❌ 报错!没有 .to() 方法
torch.save(model.state_dict(), ...)  # ❌ 没有 state_dict()

你的模型“失控”了!参数管不了、设备搬不动、无法保存!


🧩 nn.Module 内部结构(简化版)

当你写:

self.fc1 = nn.Linear(10, 20)

nn.Module 在背后做了:

  1. fc1 注册为子模块(_modules 字典)
  2. fc1.weight, fc1.bias 注册为参数(_parameters 字典)
  3. 递归管理所有子模块的参数和设备

→ 所以你才能:

model.parameters()        # 返回所有参数
model.to('cuda')          # 一键搬所有参数和缓冲区去GPU
model.state_dict()        # 返回所有参数的字典
model.load_state_dict(...) # 一键加载

🎯 最佳实践:什么时候该继承 nn.Module

场景是否继承 nn.Module举例
构建完整模型或可复用模块✅ 必须Transformer, ResNetBlock, MyClassifier
临时计算、无参数操作❌ 不用F.relu, F.softmax 等函数
自定义带参数层✅ 必须自定义 Attention、自定义 Embedding

✅ 总结卡片:

项目说明
是什么所有神经网络模块的基类
核心作用自动管理参数、子模块、设备、模式、保存/加载
必须继承吗?如果你要构建可训练、可保存、可迁移的模块 → 必须!
典型子类nn.Linear, nn.Conv2d, nn.Transformer, 你写的任何模型
关键方法forward(), parameters(), to(), train(), eval(), state_dict()

🧠 记忆口诀:

“想当AI乐高大师?先继承 nn.Module!”

  • 它是积木底座
  • 它是参数管家
  • 它是设备搬运工
  • 它是模型保险箱

现在你彻底理解了 nn.Module 的核心地位 —— 它是 PyTorch 生态的“粘合剂”和“管理者”。

没有它,你的模型就是一盘散沙 🏖️
有了它,你可以构建任何复杂网络 🏗️🚀