深度学习-梯度爆炸原因分析、调试记录与解决方案(loss突然变为nan)

477 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路

结论:

该项目是由于脏数据造成的梯度爆炸的,剔除就好了。

1. 程序崩溃现场

在这里插入图片描述 这里出现了,函数返回了nan,导致后边无法进行梯度更新。

2. 原因分析与调试记录

2.1 定位错误出现位置

因为提示信息不明显,所以我需要把详细的数据打印出来,让它报错,看具体哪里出错的。 在这里插入图片描述 (我的dataloader,数据打乱设置的False,所以数据出错的位置是不会变的)可以看到在1087处出现的错误,但打印的是i+1控制的,所以是i=1086处出现了错误。

2.2 打上断点进行分析

要是逐个再运行到1086处太慢了,于是前1085个直接continue过去,然后在第一个计算loss的位置打上断点,对主程序进行debug 在这里插入图片描述

2.3 debug进入损失函数运行的地方,找原因

在这里插入图片描述 发现,确实有一个数据不正常,area[2] = 0.0,它在损失函数中充当的是分母。 于是定位是数据处理的有问题,于是找到计算area的源数据,将其可视化出来。发现是黑乎乎一片。我在数据集里找到了那个文件,发现确实,是数据加载的有问题。 现在定位出问题是脏数据,那么数据集里很可能还有别的脏数据。接下来目标变为,处理脏数据。 在这里插入图片描述

2.3 处理脏数据

找到加载数据的地方,注释掉原来写的__getitem__, 因为刚才主要是mask出现了问题,那么现在只处理与mask相关的数据,其他的删掉,提高效率。area是有mask算出来的,且脏数据的特征是area为0.0,那么遇到它将数据的id打印出来。 在这里插入图片描述 现在找到了出问题的数据。 那么将其剔除,删除太麻烦,因为与其关联的数据可不止这3个出问题的数据。其实改变思路,遇到脏数据不加载即可。 在这里插入图片描述 遇到它们的id跳过即可。 现在,程序以及可以正常运行了。 在这里插入图片描述