优化器(9)

124 阅读2分钟

什么是优化器

pytorch的优化器:管理并更新模型中可学习参数的值,使得模型输出更接近真实标签 导数:函数在指定坐标轴上的变化率 方向导数:指定方向上的变化率 梯度:一个向量,方向为方向导数取得最大值的方向

优化器的基本属性

基本属性:

  • defaults:优化器超参数
  • state:参数的缓存,如momentum的缓存
  • params_groups:管理的参数组
  • _step_count:记录更新次数,学习率调整中使用
class Optimizer(object):
    def __init__(self, params, defaults):
        self.defaults = defaults
        self.state = defaultdict(dict)
        self.param_groups = []
        param_groups = [{'params': param_groups}]

读者可以自己调式进去pytorch源码一探究尽

优化器的基本方法

step

执行一步更新

import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
import torch
import torch.optim as optim
from tools.common_tools import set_seed

set_seed(1)  # 设置随机种子

weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))

optimizer = optim.SGD([weight], lr=0.1)

flag = 1
if flag:
    print("weight before step:{}".format(weight.data))
    optimizer.step()        # 修改lr=1 0.1观察结果
    print("weight after step:{}".format(weight.data))

zero_grad

zero_grad():清空所管理参数的梯度

pytorch特性:张量梯度不自动清零

class Optimizer(object):
    def zero_grad(self):
        for group in self.param_groups:
        for p in group['params']:
        if p.grad is not None:
        p.grad.detach_()
        p.grad.zero_()

调用如下:

flag = 0
# flag = 1
if flag:

    print("weight before step:{}".format(weight.data))
    optimizer.step()        # 修改lr=1 0.1观察结果
    print("weight after step:{}".format(weight.data))

    print("weight in optimizer:{}\nweight in weight:{}\n".format(id(optimizer.param_groups[0]['params'][0]), id(weight)))

    print("weight.grad is {}\n".format(weight.grad))
    optimizer.zero_grad()
    print("after optimizer.zero_grad(), weight.grad is\n{}".format(weight.grad))

add_param_group

添加参数组

class Optimizer(object):
    def add_param_group(self, param_group):
        for group in self.param_groups:
            param_set.update(set(group['params’]))
        self.param_groups.append(param_group)
flag = 0
# flag = 1
if flag:
    print("optimizer.param_groups is\n{}".format(optimizer.param_groups))

    w2 = torch.randn((3, 3), requires_grad=True)

    optimizer.add_param_group({"params": w2, 'lr': 0.0001})

    print("optimizer.param_groups is\n{}".format(optimizer.param_groups))

state_dict和load_state_dict

state_dict():获取优化器当前状态信息字典 load_state_dict():加载状态信息字典

flag = 0
# flag = 1
if flag:

    optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
    opt_state_dict = optimizer.state_dict()

    print("state_dict before step:\n", opt_state_dict)

    for i in range(10):
        optimizer.step()

    print("state_dict after step:\n", optimizer.state_dict())

    torch.save(optimizer.state_dict(), os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))
flag = 0
# flag = 1
if flag:

    optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
    state_dict = torch.load(os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))

    print("state_dict before load state:\n", optimizer.state_dict())
    optimizer.load_state_dict(state_dict)
    print("state_dict after load state:\n", optimizer.state_dict())

优化器的种类

SGD

Momentum(动量,冲量):结合当前梯度与上一次更新信息,用于当前更新 image.png optim.SGD 主要参数:

  • params:管理的参数组
  • lr:初始学习率
  • momentum:动量系数,贝塔
  • weight_decay:L2正则化系数
  • nesterov:是否采用NAG

image.png

其他的优化器选择

  1. optim.SGD:随机梯度下降法
  2. optim.Adagrad:自适应学习率梯度下降法
  3. optim.RMSprop:Adagrad的改进
  4. optim.Adadelta:Adagrad的改进
  5. optim.Adam:RMSprop结合Momentum
  6. optim.Adamax:Adam增加学习率上限
  7. optim.SparseAdam:稀疏版的Adam
  8. optim.ASGD:随机平均梯度下降
  9. optim.Rprop:弹性反向传播
  10. optim.LBFGS:BFGS的改进

推荐感兴趣的读者自行搜索相关的论文进行研究