Pytorch model.eval()的作用

671 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路

使用pytorch训练和预测时会分别使用到以下两行代码:

model.train() 
model.eval()

后来想了解model.eval()的具体作用,在网上查找资料大都是以下原因: 模型中有BatchNormalization和Dropout,在预测时使用model.eval()后会将其关闭以免影响预测结果。

但是没有找到BN和Dropout是具体如何影响预测结果的,直到看到这篇博客中的内容才有所理解,个人理解如下:

1)训练过程中BN的变化。

在训练过程中BN会不断的依据训练数据计算均值和方差,训练结束后得到最终的均值和方差,在此处将其记为mean_train,variance_train。

2)预测过程中BN的变化。

如果使用model.eval()则BN层就会直接使用训练过程中得到的均值和方差来对测试数据进行预测,此时能够保证预测结果不受影响。

预测过程中如果不使用model.eval()的话,BN层还是会根据输入的预测数据继续计算均值和方差,假设输入一条预测数据后,BN层计算得到其均值和方差分别为mean_test,variance_test,此时BN层的均值和方差则变成了(mean_train+mean_test),(variance_train+variance_test),相比于训练过程中的均值和方差发生了变化因此会导致预测结果发生变化。

3)训练过程中Dropout的变化

训练过程中依据设置的dropout比例会使一部分的网络连接不进行计算。

4)预测过程中Dropout的变化

预测过程中如果不使用model.eval()的话,依然会使一部分的网络连接不进行计算,而使用model.eval()后就是所有的网络连接均进行计算。