深入浅出 Pytorch 系列 — 激活函数(下)

815 阅读7分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第12天,点击查看活动详情

分析激活函数的作用

在实现了将激活函数进行可视化后,目标转移到进一步深入研究激活函数的作用,接下来将借助一个简单神经网络来在FashionMNIST 数据集上进行训练。然后查看模型各个侧面,包括梯度传递的性能

class BaseNetwork(nn.Module):

    def __init__(self, act_fn, input_size=784, num_classes=10, hidden_sizes=[512, 256, 256, 128]):
        """
        Inputs:
            act_fn - 激活函数的对象,作为神经网络非线性部分
            input_size - 输入图像尺寸(单位为像素)
            num_classes - 要预测的类别数
            hidden_sizes - 隐藏层神经元数量列表
        """
        super().__init__()

        # 根据指定的隐藏层大小来初始化神经网络
        layers = []
        layer_sizes = [input_size] + hidden_sizes
        for layer_index in range(1, len(layer_sizes)):
            layers += [nn.Linear(layer_sizes[layer_index-1], layer_sizes[layer_index]),
                       act_fn]
        layers += [nn.Linear(layer_sizes[-1], num_classes)]
        self.layers = nn.Sequential(*layers) # nn.Sequential summarizes a list of modules into a single module, applying them in sequence

        # 将所有参数保存为字典格式,便于随后加载模型时使用
        self.config = {"act_fn": act_fn.config, "input_size": input_size, "num_classes": num_classes, "hidden_sizes": hidden_sizes}

    def forward(self, x):
        #将输入展平
        x = x.view(x.size(0), -1) 
        out = self.layers(x)
        return out
{"act_fn": {"name": "Tanh"}, "input_size": 784, "num_classes": 10, "hidden_sizes": [512, 256, 256, 128]}
获取模型配置文件
def _get_config_file(model_path, model_name):
    # 保存参数的文件
    return os.path.join(model_path, model_name + ".config")
获取模型文件
def _get_model_file(model_path, model_name):
    return os.path.join(model_path, model_name + ".tar")
加载模型
def load_model(model_path, model_name, net=None):
    """
    加载保存的模型

    Inputs:
        model_path - 存放模型的 checkpoint 的路径
        model_name - 模型的名称
        net - (可选项) 如果传入模型这将参数加载到该模型,否则新建一个模型
    """
    config_file, model_file = _get_config_file(model_path, model_name), _get_model_file(model_path, model_name)
    assert os.path.isfile(config_file), f"Could not find the config file \"{config_file}\". Are you sure this is the correct path and you have your model config stored here?"
    assert os.path.isfile(model_file), f"Could not find the model file \"{model_file}\". Are you sure this is the correct path and you have your model stored here?"
    with open(config_file, "r") as f:
        config_dict = json.load(f)
        
     #如果没有传入网络模型这加载
    if net is None:
        act_fn_name = config_dict["act_fn"].pop("name").lower()
        act_fn = act_fn_by_name[act_fn_name](**config_dict.pop("act_fn"))
        net = BaseNetwork(act_fn=act_fn, **config_dict)
    net.load_state_dict(torch.load(model_file, map_location=device))
    return net
保存模型
def save_model(model, model_path, model_name):
    """
    对于给定的模型,保存模型状态和超参数
    Inputs:
        model - 网络模型对象Network object to save parameters from
        model_path - 保存 checkpoint 路径
        model_name - 模型的名称
    """
    config_dict = model.config
    os.makedirs(model_path, exist_ok=True)
    config_file, model_file = _get_config_file(model_path, model_name), _get_model_file(model_path, model_name)
    with open(config_file, "w") as f:
        json.dump(config_dict, f)
    torch.save(model.state_dict(), model_file)

如前所述,现在神经网络深度都比较深层,一般都有上百层的网络结构,那么激活函数能够保证梯度在网络中有效地传递就显得非常重要,例如如果网络有 50 多层,我们知道神经网络是一个嵌套函数,那么也就是反向传播梯度时,会经过 50 多次激活函数,还能确保梯度是一个比较合理数是一件比较困难的事。

import torchvision
from torchvision.datasets import FashionMNIST
from torchvision import transforms

首先对图像转为 tensor 然后进行标准化

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

FashionMNIST

我们还设置了我们想要训练它的数据集,即FashionMNIST。FashionMNIST是MNIST的一个更复杂的版本,包含衣服的黑白图像而不是数字。10个类别包括长裤、大衣、鞋子、包等等。为了加载这个数据集,我们将利用另一个PyTorch包,即torchvision(文档)。torchvision包由流行的数据集、模型架构和计算机视觉的常见图像转换组成。我们将在本课程的许多笔记本中使用该包,以简化我们的数据集处理。

下载训练数据集,然后将训练数据集拆分为训练集和验证集 2 个部分

train_dataset = FashionMNIST(root=DATASET_PATH, train=True, transform=transform, download=True)
train_set, val_set = torch.utils.data.random_split(train_dataset, [50000, 10000])

定义测试数据集加载器

test_set = FashionMNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

我们定义一系列数据加载器,这些数据加载用于加载训练数据、验证数据或者测试数据。值得注意的是在实际训练过程会采用小批量的加载器

train_loader = data.DataLoader(train_set, batch_size=1024, shuffle=True, drop_last=False)
val_loader = data.DataLoader(val_set, batch_size=1024, shuffle=False, drop_last=False)
test_loader = data.DataLoader(test_set, batch_size=1024, shuffle=False, drop_last=False)

exmp_imgs = [train_set[i][0] for i in range(16)]
# Organize the images into a grid for nicer visualization
img_grid = torchvision.utils.make_grid(torch.stack(exmp_imgs, dim=0), nrow=4, normalize=True, pad_value=0.5)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8,8))
plt.title("FashionMNIST examples")
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()

activation_001.png

def visualize_gradients(net, color="C0"):
    """
    输入:
        net -  BaseNetwork 的对象
        color 
    """
    # 将网络设置为评估,也就是不会更新参数
    net.eval()
    # 加载小批量,数量 256 张图像
    small_loader = data.DataLoader(train_set,batch_size=256,shuffle=False)
    # 进行一次迭代
    imgs, labels = next(iter(small_loader))
    # 如果使用 GPU 切换一下,暂时用的 CPU
    imgs, labels = imgs.to(device), labels.to(device)
    
    #将一个批次图像输入到网络中,然后计算网络权重的梯度
    # 将梯度初始化清空
    net.zero_grad()
    
    preds = net(imgs)
    #采用交叉熵
    loss = F.cross_entropy(preds,labels)
    #反向传播,计算参数的梯度
    loss.backward()
    # 仅查看名称为 weight 的参数的梯度
    grads ={name:params.grad.data.view(-1).cpu().clone().numpy() for name, params in net.named_parameters() if "weight" in name}
    net.zero_grad()
import warnings
warnings.filterwarnings('ignore')
act_fn = act_fn_by_name['sigmoid']()
net_actfn = BaseNetwork(act_fn=act_fn).to(device)
visualize_gradients(net_actfn)

activation_002.png

回顾一下网络结构分别为 512 256 256 128 10 这里时 Sigmoid 激活函数,明显不是很理想,虽然输出层梯度比较大,可以达到 0.1 不过当梯队回传到输入层时候,只有 1e-5

act_fn = act_fn_by_name['tanh']()
net_actfn = BaseNetwork(act_fn=act_fn).to(device)
visualize_gradients(net_actfn)

activation_003.png

# for i,act_fn_name in enumerate(act_fn_by_name):
#     set_seed(42)
act_fn = act_fn_by_name['relu']()
net_actfn = BaseNetwork(act_fn=act_fn).to(device)
visualize_gradients(net_actfn)

activation_005.png

import seaborn as sns
def visualize_gradients(net, color="C0"):
    """
    输入:
        net -  BaseNetwork 的对象
        color 
    """
    # 将网络设置为评估,也就是不会更新参数
    net.eval()
    # 加载小批量,数量 256 张图像
    small_loader = data.DataLoader(train_set,batch_size=256,shuffle=False)
    # 进行一次迭代
    imgs, labels = next(iter(small_loader))
    # 如果使用 GPU 切换一下,暂时用的 CPU
    imgs, labels = imgs.to(device), labels.to(device)
    
    #将一个批次图像输入到网络中,然后计算网络权重的梯度
    # 将梯度初始化清空
    net.zero_grad()
    print(imgs.shape)
    preds = net(imgs)
    #采用交叉熵
    loss = F.cross_entropy(preds,labels)
    #反向传播,计算参数的梯度
    loss.backward()
    # 仅查看名称为 weight 的参数的梯度
    grads ={name:params.grad.data.view(-1).cpu().clone().numpy() for name, params in net.named_parameters() if "weight" in name}
    net.zero_grad()
#     print(grads)
    #绘制
    
    
    columns = len(grads)
    fig, ax = plt.subplots(1, columns, figsize=(columns *3.5,2.5))
    fig_index = 0
    for key in grads:
        key_ax = ax[fig_index%columns]
        sns.histplot(data=grads[key],bins=30,ax=key_ax,color=color,kde=True)
        key_ax.set_title(str(key))
        key_ax.set_xlabel("Grad magnitute")
        fig_index += 1
        
    fig.suptitle(f"Gradient magnitude distribution for activation function {net.config['act_fn']['name']}",fontsize=14,y=1.05)
    fig.subplots_adjust(wspace=0.45)
    plt.show()
    plt.close()

训练模型

接下来,要在 FashionMNIST 上用不同激活函数训练模型,进行横向比较性能。总而言之,我们先定义训练模型的函数

def test_model(net,data_loader):
    net.eval()
    true_preds, count = 0.,0
    for imgs, labels in data_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        with torch.no_grad():
            preds = net(imgs).argmax(dim=-1)
            true_preds += (preds == labels).sum().item()
            count += labels.shape[0]
    test_acc = true_preds / count
    return test_acc
def train_model(net, model_name,max_epochs=50, patience=7,batch_size=256,overwrite=False):
    """
    net 模型对象
    model_name 模型名称
    max_epochs 训练的最大迭代数
    patience:在验证集上经历多少迭代没有改善就停止训练
    overwirte 是否覆盖已经存在的 checkpoint
    """
    
    file_exists = os.path.isfile(_get_model_file(CHECKPOINT_PATH,model_name))
    
    if file_exists and not overwrite:
        print("Model file already exists. Skipping training...")
    else:
        if file_exists:
            print("Model file exists, but will be overwritten...")
        #采用 SGD 优化器
        optimizer = optim.SGD(net.paramters(),lr=1e-2,momentum=0.9)
        #损失函数采用交叉熵
        loss_module = nn.CrossEntropyLoss()
        #加载数据
        train_loader_local = data.DataLoader(train_set,batch_size=batch_size,shuffle=True,drop_last=True,pin_memory=False)
        #验证集分数
        val_scores = []
        beat_val_epoch = -1
        #训练网络
        for epoch in range(max_epochs):
            # Training
            net.train()
            true_preds, count =0.0
            for imgs, labels in tqdm(train_loader_local,desc=f"Epoch {epoch + 1}",leave=False):
                imgs, labels = imgs.to(device), labels.to(device)
                #将梯度清空
                optimizer.zero_grad()
                #预测图像
                preds = net(imgs)
                loss = loss_module(preds,labels)
                loss.backward()
                #更新参数
                optimizer.step()
                #统计批次样本中正确分类的样本数
                true_preds += (preds.argmax(dim=-1) == labels).sum()
                #统计样本数
                count += labels.shape[0]
            #计算准确率
            train_acc = true_preds / count
            #测试网络
            val_acc = test_model(net,val_loader)
            val_scores.append(val_acc)
            
            print(f"[Epoch {epoch + 1:2d}] Training accuracy: {train_acc*100.0:0.05.2f}%, Validation accuracy:{val_acc*100.0:0.05.2f}%")
            #判断验证集精度以及当前验证集上准确度是否大于
            if len(val_scores) == 1 or val_acc > val_scores[best_val_epoch]:
                print("\t (New best performance saving model...)")
                save_model(net,CHECKPOINT_PATH,model_name)
                best_val_epoch = epoch
            elif best_val_epoch <= epoch - patience:
                print(f"Early stopping due to no improvement over the last {patience} epochs")
                break
        #绘制曲线
        plt.plot([i for i in range(1,len(val_scores) + 1)], val_scores)
        plt.xlabel("Epochs")
        plt.ylabel("Validation accuracy")
        plt.title(f"Validation performance of {model_name}")
        plt.show()
        plt.close()
    load_model(CHECKPOINT_PATH,model_name,net=net)
    test_acc = test_model(net,test_loader)
    print((f"Test accuracy:{test_acc*100.0:4.2f}% ").center(50,"=") + "\n")
for act_fn_name in act_fn_by_name:
    print(f"Training BaseNetwork with {act_fn_name} activation...")
    set_seed(42)
    act_fn = act_fn_by_name[act_fn_name]()
    net_actfn = BaseNetwork(act_fn=act_fn).to(device)
    train_model(net_actfn,f"FashionMNIST_{act_fn_name}",overwrite=False)
Training BaseNetwork with sigmoid activation...
Model file already exists. Skipping training...
==============Test accuracy:10.00% ===============

Training BaseNetwork with tanh activation...
Model file already exists. Skipping training...
==============Test accuracy:87.59% ===============

Training BaseNetwork with relu activation...
Model file already exists. Skipping training...
==============Test accuracy:88.62% ===============

Training BaseNetwork with leakyrelu activation...
Model file already exists. Skipping training...
==============Test accuracy:88.92% ===============

Training BaseNetwork with elu activation...
Model file already exists. Skipping training...
==============Test accuracy:87.27% ===============

Training BaseNetwork with swish activation...
Model file already exists. Skipping training...
==============Test accuracy:88.73% ===============

从结果上来看,采用 Sigmoid 激活函数的模型,10 预测结果为 10% 基本和猜没有什么区别,也就是说明没有学到东西,从采用其他激活函数的结果来看,上下相差的并不多,为了得到更准确的结果,需要进行多次实验,然后看平均值,这样才更有说服力。

想要得到一个好网络,不仅仅看激活函数,还需要看其他因素,例如隐藏层的大小、网络深度、隐藏层的类型、任务、数据集、优化器和学习率等等