一文讲清楚人工智能正则化--早停法

603 阅读6分钟

前言

在前面一篇文章中,笔者详细讲解了人工智能深度学习中常用的正则化方法--Dropout正则化技术,但是大家有没有发现,我们前几种的正则化方法都是在模型的损失函数上稍作改动来达到预防过拟合的效果,那么有没有一种方法可以不改变模型的损失函数或者是结构的同时又能够预防模型过拟合呢 ? 那就是这篇文章中我们要讲的--早停法。

一、定义与原理

1. 定义

早停(Early Stopping)是一种有效的正则化技术,用于防止机器学习和深度学习模型在训练过程中出现过拟合。其核心思想是在模型训练过程中,通过监控模型在验证集上的性能来决定是否提前结束训练。

2. 早停法的原理图

1212a14f8929c463704c17ec86dd0542.png

3. 早停法的工作原理(结合上图更直观)

  1. 性能监控:在每个训练周期(epoch)结束时,评估模型在独立的验证集上的性能,通常是使用损失函数(如交叉熵损失或均方误差)来衡量。
  2. 性能比较:将当前周期的验证集性能与之前最佳性能(例如,最低的损失)进行比较。
  3. 更新最佳模型:如果当前周期的性能更好,则更新记录的最佳性能,并保存当前模型参数。
  4. 早停条件:如果在连续多个周期(由耐心参数patience控制)内,模型在验证集上的性能没有进一步改善,则停止训练。(其实这个早停的条件大家可以根据自己的业务或者项目的需求而自己设定的)

早停法在神经网络中的意义:

这里笔者看了一个UP主的文章,感觉他这一部分讲的很好,笔者引用一段他的讲解。

为了获得性能良好的神经网络,网络定型过程中需要进行许多关于所用设置(超参数)的决策。超参数之一是定型周期(epoch)的数量:亦即应当完整遍历数据集多少次(一次为一个epoch)如果epoch数量太少,网络有可能发生欠拟合(即对于定型数据的学习不够充分);如果epoch数量太多,则有可能发生过拟合(即网络对定型数据中的“噪声”而非信号拟合)。早停法旨在解决epoch数量需要手动设置的问题

如果对于这个工作原理还是不理解的,下面笔者会给出自己的一个模型的训练日志做例子:

image.png

详解:

大家看笔者该模型的这部分训练日志容易发现,其实在 epoch 为 7 ,step 为1100的时候模型的总损失值此时已经为该部分的最低点了,后面模型的损失值随着训练轮次 epoch 的增加不降反增也可以看出对于该模型来说,epoch 7 的时候其实已经拟合到了一个不错的效果了,此时如果再继续训练下去,模型就会存在过拟合风险,模型的效果恐怕会不如之前,甚者可能大打折扣,此时我们的早停法就派上用场了,可以在一定条件满足时直接停止模型的训练,减少过拟合风险,使模型停留在一个效果较好的水平上。

二、早停法的应用

1. 优点

  • 防止过拟合:通过提前结束训练,避免模型在训练数据上过度拟合。
  • 节省资源:减少不必要的训练时间,节省计算资源。
  • 提高泛化能力:选择在验证集上表现最佳的模型参数,有助于提高模型在未知数据上的表现。

2. 怎么使用早停法(步骤)

  • 初始化:设置最佳验证损失(或其他指标)为一个很大的值,初始化耐心计数器。
  • 训练循环
    • 在每个epoch结束后,计算验证集的损失。
    • 如果当前验证损失小于最佳损失:
      • 更新最佳损失。
      • 重置耐心计数器。
      • 保存当前模型参数。
    • 如果当前验证损失大于最佳损失:
      • 增加耐心计数器。
      • 如果耐心计数器超过设定值,则停止训练。

3.注意事项

  • 选择合适的耐心值:耐心值应根据具体任务和数据集调整,太小可能导致过早停止,太大可能导致训练时间过长。
  • 与其他正则化方法结合:早停法可以与权重衰减、Dropout等其他正则化技术结合使用,以进一步提高模型的泛化能力。
  • 使用交叉验证:在某些情况下,使用交叉验证来评估早停法的效果可能更准确,因为它可以提供更稳健的性能估计。

总的来说,早停法是一种简单而有效的技术,并被广泛应用于各种机器学习和深度学习模型的训练过程中,以提高模型的泛化性能并减少过拟合风险。但是大家还是要慎用,因为有些时候,当模型陷于局部最优值的假象时,早停法反而可能会把模型的性能拉低。

三、示例代码辅助理解

以下是一个早停法的简单示例,使用的是Python和Scikit-learn库:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV
from sklearn.early_stopping import EarlyStoppingClassifier

# 创建一个模拟数据集
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)

# 划分训练集和验证集
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化随机森林分类器
rf = RandomForestClassifier(random_state=42)

# 设置早停法的参数
# n_iter_no_change: 验证集的性能在多少个迭代后没有改善时停止训练
# 这里我们设置为5
# 验证集:用于早停法的验证数据
# 这里我们使用X_val, y_val
estimator = EarlyStoppingClassifier(rf, n_iter_no_change=5, validation_split=0.2)

# 训练模型
estimator.fit(X_train, y_train)

# 评估模型
y_pred = estimator.predict(X_val)
accuracy = accuracy_score(y_val, y_pred)
print(f"Accuracy: {accuracy:.2f}")

# 打印早停法的相关信息
print(f"Training for {estimator.n_iter_} iterations.")
print(f"Best score: {estimator.best_score_:.2f}")

这个示例中我先生成了一个模拟的分类数据集,然后将其划分为训练集和验证集。我使用EarlyStoppingClassifier包装了随机森林分类器,并设置了早停的参数。n_iter_no_change参数定义了在验证集的性能在多少个迭代后没有改善时停止训练。validation_split参数定义了从训练数据中划分出多少比例的数据用于早停法的验证。训练完成后,我评估了模型在验证集上的准确率,并打印了早停法的一些相关信息,比如实际训练的迭代次数和最佳验证集得分。

四、Reference

blog.csdn.net/weixin_4051…

以上就是笔者关于人工智能正则化技术--早停法的讲解,欢迎大家点赞,收藏,交流和关注,O(∩_∩)O谢谢!