pytorch笔记之正确的inference

380 阅读1分钟

在非训练状态的预测过程中要注意下面两点

  • model.eval():令 model 中的BatchNorm, Dropout等 module 采用 eval mode,保证 inference 结果的正确性。也就说舍弃掉中的BatchNorm和Dropout等加速训练的步骤,充分利用模型中的既有参数,让预测结果具有唯一性
  • torch.no_grad():声明不计算梯度,节省大量内存和显存。模型中正常调用的函数都会默认加上autograd功能,在反向传播时自动计算梯度,而这会消耗一部分计算资源和内存空间,同时降低预测效率,使用torch.no_grad()则表示取消计算梯度这一步骤。使用方法一般分为两种,一种是当做装饰器在撰写预测函数时使用,另一种是with torch.no_grad():的方式将模型预测过程放置在其中使用