Pytorch model.train()、model.eval()
测试模型时前面加:model.eval()
。
但是不写这两个方法,模型也可以运行,这是因为这两个方法是针对在网络训练和测试时采用不同方式的情况,比如 \(Batch\ Normalization、Dropout\)。
\(Dropout\):在训练过程的前向传播中,让每个神经元以一定的概率 \(p\) 处于不激活状态。以达到减少过拟合的效果。
-
训练时针对每个 \(min-batch\) 的,但是在测试中往往是针对单张图片,即不存在 \(min-batch\) 的概念。
-
由于网络训练完毕后参数都是固定的,因此每个批次的均值和方差都是不变的,因此直接结算所有 \(batch\) 的均值和方差。
-
所有 \(Batch\ Normalization\) 的训练和测试时的操作不同。
-
在训练中,每个隐层的神经元先乘概率 \(P\),然后在进行激活,在测试中,所有神经元先进行激活,然后每个隐层神经元的输出乘 \(P\)。
补充:\(Pytorch\) 踩坑记录——model.eval()
最近在写代码时遇到一个问题,原本训练好的模型,加载进来进行 \(inference\) 准确率直接掉了 \(5\) 个点。
对训练好的模型加载进来准确率与原先的不符,常见两方面原因:
① \(data\)。
检查前后两次加载的 \(data\) 有无变化。
首先检查 \(transforms.Normalize\) 使用的均值和方差是否和训练时相同;
另外检查在这个过程中数据是否经过存储形式的改变,这有可能会带来数据精度的变化导致一定信息的丢失。比如一个数据集,原先将图片存储成向量形式,但其对应的是 “\(png\)” 格式的数据,而后进行了一次 \(data-to-img\) 操作,将向量转换成了 “\(jpg\)” 形式,这时加载进来便造成了掉点。
② \(model.state\_dict()\)。
上面的掉点一般不会太严重,第二方面造成的掉点就比较严重了,一旦模型的参数加载错了,那误差会很大。
如果是参数没有正确加载进来则比较容易发现,这时的准确率非常低,几乎等于瞎猜。
而我这次遇到的情况是,准确率并不是很低,只是掉了几个点,检查了多次,均显示模型参数已经成功加载了。后来仔细查看后发现在其中一次调用模型进行 \(inference\) 时,忘了写 model.eval()
,造成了模型参数发生变化,再次调用则出现了掉点。
于是回顾了一下 model.eval()
和 model.train()
的具体作用。如下:
model.train()
和model.eval()
一般在模型训练和评价时会加上这两句,主要是针对由于 \(model\) 在训练时和评价时 \(Batch\)
\(Normalization\) 和 \(Dropout\) 方法模型不同:
model.eval()
:不启用 \(BatchNormalization\) 和 \(Dropout\)。此时 \(pytorch\) 会自动把 \(BN\) 和 \(DropOut\) 固定住,不会取平均,而是用训练好的值。不然的话,一旦 \(test\) 的 \(batch\_size\) 过小,很容易就会因 \(BN\) 层导致模型 性能损失较大;
model.train()
:启用 \(BatchNormalization\) 和 \(Dropout\)。在模型测试阶段使用 model.train()
让模型变成训练模式,此时 \(dropout\) 和 \(batch\ normalization\) 的操作在训练起到防止网络过拟合的问题。
因此,在使用 \(PyTorch\) 进行训练和测试时一定要记得把实例化的 \(model\) 指定 \(train/eval\)。
\(model.eval()\) 与 \(troch.no\_grad()\) 的比较:
虽然两者都是 \(eval\) 的时候使用,但其作用并不相同。
model.eval()
负责改变 \(batchnorm、dropout\) 的工作方式,如在 \(eval()\) 模式下,\(dropout\) 是不工作的。见下方代码:
import torch
import torch.nn as nn
drop = nn.Dropout()
x = torch.ones(10)
drop.train() # 模型训练
print(drop(x)) # tensor([2., 2., 0., 2., 2., 2., 2., 0., 0., 2.])
drop.eval() # 模型评估
print(drop(x)) # tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
torch.no_grad()
负责关掉梯度计算,节省 \(eval\) 的时间。
只进行 \(inference\) 时,model.eval()
是必须使用的,否则会影响结果准确性。而 torch.no_grad()
并不是强制的,只影响运行效率。