深入理解 PyTorch 中的激活函数
激活函数是神经网络中不可或缺的组成部分,它们引入了非线性,使得神经网络能够学习复杂的模式和特征。在 PyTorch 中,torch.nn 模块提供了多种常用的激活函数,本文将对这些激活函数进行总结和介绍。
1. 什么是激活函数?
激活函数的主要作用是将输入信号映射到输出信号,并引入非线性特性。没有激活函数的神经网络本质上是线性变换的堆叠,无法处理复杂的非线性问题。
激活函数的主要特点包括:
- 非线性:允许网络学习复杂的模式。
- 可微性:支持反向传播算法。
- 数值稳定性:避免梯度爆炸或梯度消失问题。
2. PyTorch 中的激活函数分类
PyTorch 提供了多种激活函数,以下是常见的分类:
2.1 基本激活函数
-
ReLU (Rectified Linear Unit) :
-
定义:
ReLU(x) = max(0, x) -
特点:简单高效,解决了梯度消失问题。
-
示例:
-
import torch.nn as nn
relu = nn.ReLU()
-
Sigmoid:
-
定义:
Sigmoid(x) = 1 / (1 + exp(-x)) -
特点:将输出映射到 (0, 1),适合二分类问题。
-
示例:
sigmoid = nn.Sigmoid()
-
-
Tanh (Hyperbolic Tangent) :
-
定义:
Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) -
特点:将输出映射到 (-1, 1),比 Sigmoid 更适合处理零均值数据。
-
示例:
tanh = nn.Tanh() -
2.2 改进的 ReLU 变体
-
LeakyReLU:
-
定义:
LeakyReLU(x) = x if x > 0 else alpha * x -
特点:允许负值通过,缓解 ReLU 的“死亡神经元”问题。
-
示例:
leaky_relu = nn.LeakyReLU(negative_slope=0.01)
-
-
PReLU (Parametric ReLU) :
-
定义:类似 LeakyReLU,但负斜率是可学习的参数。
-
示例:
prelu = nn.PReLU() -
-
ReLU6:
-
定义:
ReLU6(x) = min(max(0, x), 6) -
特点:限制输出范围,适合移动设备上的量化网络。
-
示例:
relu6 = nn.ReLU6() -
2.3 平滑激活函数
-
Softplus:
-
定义:
Softplus(x) = log(1 + exp(x)) -
特点:平滑版的 ReLU。
-
示例:
-
softplus = nn.Softplus()
-
SiLU (Swish) :
-
定义:
SiLU(x) = x * Sigmoid(x) -
特点:在深度学习中表现优异。
-
示例:
-
silu = nn.SiLU()
-
Mish:
-
定义:
Mish(x) = x * Tanh(Softplus(x)) -
特点:自正则化激活函数,适合深层网络。
-
示例:
-
mish = nn.Mish()
2.4 归一化激活函数
-
Softmax:
-
定义:
Softmax(x_i) = exp(x_i) / sum(exp(x_j)) -
特点:将输出归一化为概率分布,常用于多分类问题。
-
示例:
softmax = nn.Softmax(dim=1)
-
-
LogSoftmax:
-
定义:
LogSoftmax(x) = log(Softmax(x)) -
特点:结合负对数似然损失,数值更稳定。
-
示例:
log_softmax = nn.LogSoftmax(dim=1)
-
3. 激活函数的选择
选择激活函数时需要考虑以下因素:
-
任务类型:
-
网络深度:
-
数值稳定性:
4. 示例代码
以下是一个简单的示例,展示如何在 PyTorch 中使用激活函数:
import torch
import torch.nn as nn
# 定义一个简单的神经网络
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(20, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return x
# 创建模型并测试
model = SimpleNN()
input_data = torch.randn(5, 10)
output = model(input_data)
print(output)
5. 总结
激活函数是神经网络的核心组件,它们赋予网络非线性能力,使其能够处理复杂的任务。在 PyTorch 中,提供了多种激活函数以满足不同的需求。选择合适的激活函数可以显著提升模型的性能。
希望本文能帮助你更好地理解和使用 PyTorch 中的激活函数
6.补充用法
1. Threshold
公式:
代码示例:
import torch
import torch.nn as nn
m = nn.Threshold(0.1, 20)
input = torch.tensor([-0.5, 0.2, 0.8])
output = m(input)
print(output) # 输出:[20.0, 0.2, 0.8]
2. ReLU
公式:
代码示例:
relu = nn.ReLU()
input = torch.tensor([-1.0, 0.0, 1.0])
output = relu(input)
print(output) # 输出:[0.0, 0.0, 1.0]
3. RReLU
公式: 其中 是从均匀分布 中随机采样的值。
代码示例:
rrelu = nn.RReLU(lower=0.1, upper=0.3)
input = torch.tensor([-1.0, 0.0, 1.0])
output = rrelu(input)
print(output)
4. Hardtanh
公式:
代码示例:
hardtanh = nn.Hardtanh(min_val=-1.0, max_val=1.0)
input = torch.tensor([-2.0, 0.0, 2.0])
output = hardtanh(input)
print(output) # 输出:[-1.0, 0.0, 1.0]
5. ReLU6
公式:
代码示例:
relu6 = nn.ReLU6()
input = torch.tensor([-1.0, 3.0, 7.0])
output = relu6(input)
print(output) # 输出:[0.0, 3.0, 6.0]
6. Sigmoid
公式:
代码示例:
sigmoid = nn.Sigmoid()
input = torch.tensor([-1.0, 0.0, 1.0])
output = sigmoid(input)
print(output) # 输出:[0.2689, 0.5, 0.7311]
7. Hardsigmoid
公式:
代码示例:
hardsigmoid = nn.Hardsigmoid()
input = torch.tensor([-4.0, 0.0, 4.0])
output = hardsigmoid(input)
print(output) # 输出:[0.0, 0.5, 1.0]
8. Tanh
公式:
代码示例:
tanh = nn.Tanh()
input = torch.tensor([-1.0, 0.0, 1.0])
output = tanh(input)
print(output) # 输出:[-0.7616, 0.0, 0.7616]
9. SiLU (Swish)
公式:
代码示例:
silu = nn.SiLU()
input = torch.tensor([-1.0, 0.0, 1.0])
output = silu(input)
print(output)
10. Mish
公式:
代码示例:
mish = nn.Mish()
input = torch.tensor([-1.0, 0.0, 1.0])
output = mish(input)
print(output)
11. Hardswish
公式:
代码示例:
hardswish = nn.Hardswish()
input = torch.tensor([-4.0, 0.0, 4.0])
output = hardswish(input)
print(output)
12. ELU
公式:
代码示例:
elu = nn.ELU(alpha=1.0)
input = torch.tensor([-1.0, 0.0, 1.0])
output = elu(input)
print(output)
13. Softmax
公式:
代码示例:
softmax = nn.Softmax(dim=0)
input = torch.tensor([1.0, 2.0, 3.0])
output = softmax(input)
print(output)