Weight Initialization (8)


Gradient Vanishing and Explosion

  1. E(XY) = E(X)E(Y)  (X, Y independent)
  2. D(X) = E(X²) − [E(X)]²
  3. D(X+Y) = D(X) + D(Y)  (X, Y independent)
  4. From 1–3: D(XY) = D(X)D(Y) + D(X)[E(Y)]² + D(Y)[E(X)]²

If E(X) = 0 and E(Y) = 0, then D(XY) = D(X)D(Y). For a linear layer H = WX with n inputs this gives D(h) = n·D(W)·D(X), so the output standard deviation is scaled by √(n·D(W)) at every layer: with standard-normal weights it explodes, and keeping D(W) = 1/n holds the scale constant.
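
Below is a minimal sketch (an addition to these notes, not from the original) that checks D(H) = n·D(W)·D(X) empirically with a single linear transform; the layer width of 256 and the 0.1 weight scale are arbitrary choices for illustration.

import torch

torch.manual_seed(0)
n = 256
x = torch.randn(1024, n)       # inputs with D(X) = 1
w = torch.randn(n, n) * 0.1    # weights with D(W) = 0.01
h = x @ w                      # one linear layer, no bias, no activation
print(h.std())                 # ≈ sqrt(n * D(W) * D(X)) = sqrt(2.56) = 1.6

Stacking many such layers multiplies the standard deviation by roughly the same factor at every step, which is exactly the explosion (or, for a small D(W), the vanishing) that the MLP experiment below reproduces.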

Xavier Initialization

  • Variance consistency: keep the data scale within a proper range at every layer, typically with a variance of 1
  • Activation functions: saturating functions such as Sigmoid and Tanh

Reference: "Understanding the difficulty of training deep feedforward neural networks" (Glorot & Bengio, 2010)
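
As a hedged illustration (not part of the original notes), the Xavier uniform bound a = gain · sqrt(6 / (fan_in + fan_out)) can be computed by hand and compared with what nn.init.xavier_uniform_ produces; the layer size of 256 is an arbitrary assumption.

import numpy as np
import torch.nn as nn

fan_in = fan_out = 256
tanh_gain = nn.init.calculate_gain('tanh')        # 5/3 for tanh
a = tanh_gain * np.sqrt(6 / (fan_in + fan_out))   # weights drawn from U(-a, a)

linear = nn.Linear(fan_in, fan_out, bias=False)
nn.init.xavier_uniform_(linear.weight, gain=tanh_gain)
print(a)                                          # hand-computed bound
print(linear.weight.abs().max().item())           # should sit just below a
print(linear.weight.std().item())                 # ≈ a / sqrt(3)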

Kaiming Initialization

  • Variance consistency: keep the data scale within a proper range at every layer, typically with a variance of 1
  • Activation functions: ReLU and its variants

Reference: "Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification" (He et al., 2015)
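
A minimal sketch (again an addition, not from the paper or the original notes) checking that nn.init.kaiming_normal_ in fan-in mode gives a weight standard deviation of about sqrt(2 / fan_in) for ReLU; the layer size is an arbitrary assumption.

import numpy as np
import torch.nn as nn

fan_in = 256
linear = nn.Linear(fan_in, fan_in, bias=False)
nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')

print(np.sqrt(2 / fan_in))           # theoretical std = sqrt(2 / fan_in) ≈ 0.088
print(linear.weight.std().item())    # empirical std, should be close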

nn.init.calculate_gain

nn.init.calculate_gain computes the variance scaling factor (gain) of an activation function. Main parameters:

  • nonlinearity: name of the activation function
  • param: optional parameter of the activation function, e.g. the negative_slope of LeakyReLU

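As a quick hedged illustration (the particular activations are arbitrary picks), the gains PyTorch returns:

import torch.nn as nn

print(nn.init.calculate_gain('tanh'))              # 5/3 ≈ 1.667
print(nn.init.calculate_gain('relu'))              # sqrt(2) ≈ 1.414
print(nn.init.calculate_gain('leaky_relu', 0.1))   # sqrt(2 / (1 + 0.1**2)) ≈ 1.407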

Example Code

# -*- coding: utf-8 -*-
import os
import torch
import random
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # set the random seed for reproducibility


class MLP(nn.Module):
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)
            x = torch.relu(x)

            print("layer:{}, std:{}".format(i, x.std()))
            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

        return x

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # nn.init.normal_(m.weight.data, std=np.sqrt(1/self.neural_num))    # normal: mean=0, std=1

                # a = np.sqrt(6 / (self.neural_num + self.neural_num))
                #
                # tanh_gain = nn.init.calculate_gain('tanh')
                # a *= tanh_gain
                #
                # nn.init.uniform_(m.weight.data, -a, a)

                # nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)

                # nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))
                nn.init.kaiming_normal_(m.weight.data)

flag = 0
# flag = 1

if flag:
    layer_nums = 100
    neural_nums = 256
    batch_size = 16

    net = MLP(neural_nums, layer_nums)
    net.initialize()

    inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1

    output = net(inputs)
    print(output)

# ======================================= calculate gain =======================================

# flag = 0
flag = 1

if flag:

    x = torch.randn(10000)
    out = torch.tanh(x)

    gain = x.std() / out.std()
    print('gain:{}'.format(gain))

    tanh_gain = nn.init.calculate_gain('tanh')
    print('tanh_gain in PyTorch:', tanh_gain)
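
A note on what to expect (hedged, based on the formulas above rather than a logged run): with the plain standard-normal initialization the per-layer std grows explosively and reaches nan within a few dozen of the 100 layers, while kaiming_normal_ keeps the activation scale roughly stable all the way through. The second experiment typically prints an empirical gain of about 1.6 for tanh, close to the 5/3 ≈ 1.667 returned by nn.init.calculate_gain('tanh').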