什么是优化器
pytorch的优化器:管理并更新模型中可学习参数的值,使得模型输出更接近真实标签 导数:函数在指定坐标轴上的变化率 方向导数:指定方向上的变化率 梯度:一个向量,方向为方向导数取得最大值的方向
优化器的基本属性
基本属性:
- defaults:优化器超参数
- state:参数的缓存,如momentum的缓存
- params_groups:管理的参数组
- _step_count:记录更新次数,学习率调整中使用
class Optimizer(object):
def __init__(self, params, defaults):
self.defaults = defaults
self.state = defaultdict(dict)
self.param_groups = []
param_groups = [{'params': param_groups}]
读者可以自己调式进去pytorch源码一探究尽
优化器的基本方法
step
执行一步更新
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
import torch
import torch.optim as optim
from tools.common_tools import set_seed
set_seed(1) # 设置随机种子
weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))
optimizer = optim.SGD([weight], lr=0.1)
flag = 1
if flag:
print("weight before step:{}".format(weight.data))
optimizer.step() # 修改lr=1 0.1观察结果
print("weight after step:{}".format(weight.data))
zero_grad
zero_grad():清空所管理参数的梯度
pytorch特性:张量梯度不自动清零
class Optimizer(object):
def zero_grad(self):
for group in self.param_groups:
for p in group['params']:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
调用如下:
flag = 0
# flag = 1
if flag:
print("weight before step:{}".format(weight.data))
optimizer.step() # 修改lr=1 0.1观察结果
print("weight after step:{}".format(weight.data))
print("weight in optimizer:{}\nweight in weight:{}\n".format(id(optimizer.param_groups[0]['params'][0]), id(weight)))
print("weight.grad is {}\n".format(weight.grad))
optimizer.zero_grad()
print("after optimizer.zero_grad(), weight.grad is\n{}".format(weight.grad))
add_param_group
添加参数组
class Optimizer(object):
def add_param_group(self, param_group):
for group in self.param_groups:
param_set.update(set(group['params’]))
self.param_groups.append(param_group)
flag = 0
# flag = 1
if flag:
print("optimizer.param_groups is\n{}".format(optimizer.param_groups))
w2 = torch.randn((3, 3), requires_grad=True)
optimizer.add_param_group({"params": w2, 'lr': 0.0001})
print("optimizer.param_groups is\n{}".format(optimizer.param_groups))
state_dict和load_state_dict
state_dict():获取优化器当前状态信息字典 load_state_dict():加载状态信息字典
flag = 0
# flag = 1
if flag:
optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
opt_state_dict = optimizer.state_dict()
print("state_dict before step:\n", opt_state_dict)
for i in range(10):
optimizer.step()
print("state_dict after step:\n", optimizer.state_dict())
torch.save(optimizer.state_dict(), os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))
flag = 0
# flag = 1
if flag:
optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
state_dict = torch.load(os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))
print("state_dict before load state:\n", optimizer.state_dict())
optimizer.load_state_dict(state_dict)
print("state_dict after load state:\n", optimizer.state_dict())
优化器的种类
SGD
Momentum(动量,冲量):结合当前梯度与上一次更新信息,用于当前更新
optim.SGD
主要参数:
- params:管理的参数组
- lr:初始学习率
- momentum:动量系数,贝塔
- weight_decay:L2正则化系数
- nesterov:是否采用NAG
其他的优化器选择
- optim.SGD:随机梯度下降法
- optim.Adagrad:自适应学习率梯度下降法
- optim.RMSprop:Adagrad的改进
- optim.Adadelta:Adagrad的改进
- optim.Adam:RMSprop结合Momentum
- optim.Adamax:Adam增加学习率上限
- optim.SparseAdam:稀疏版的Adam
- optim.ASGD:随机平均梯度下降
- optim.Rprop:弹性反向传播
- optim.LBFGS:BFGS的改进
推荐感兴趣的读者自行搜索相关的论文进行研究