Pytorch——dropout的理解和使用
在训练CNN网络的时候,常常会使用dropout来使得模型具有更好的泛化性,并防止过拟合。而dropout的实质则是以一定概率使得输入网络的数据某些维度上变为0,这样可以使得模型训练更加有效。但是我们需要注意dropout层在训练和测试的时候,模型架构是不同的。为什么会产生这种差别呢?
一、训练和测试的不同
标准的网络连接如下图:
增加了dropout层的网络结构如下图:
此处的是服从参数为p的伯努利分布的随机数(0或者1),用一句话来说就是y这个向量的每一个维度都会以概率p变为0。
问题来了,在训练的时候我们是有dropout层的,那我们测试的时候还需要么?答案是不需要dropout层了,而是直接把y输入进网络。这时候问题又来了,如果直接把y输入进网络,在训练的时候y由于经历了dropout层,意味着y的数据分布的期望应该是会乘以p的(举个例子如果y的每一个维度都是1,那么经过dropout层之后,有些维度变成了0,那此时真正进入网络的数值分布其期望应该是p),而测试时没有经过dropout层那就意味着训练时输入和测试时输入的期望是不同的,那这个训练好的权重将无法达到最优秀的状态。那该如何处理才能保证这个期望的一致性呢?答案就是在测试时,将网络的权重全部乘以p。如下图:
这就是含有dropout层的训练与测试不同的地方。
二、Pytorch实现dropout
前文讲了dropout训练和测试时的不同,接下来我们讲讲如何在Pytorch建模时使用这个层。
pytorch实现dropout的方式主要有两个,第一个是F.dropout(out, p=0.5),第二个是nn.Dropout(p=0.5),这两者的区别其实就是F和nn的区别。第一个是一个函数,第二个是一个nn.model类。那在实际使用中我们该使用什么呢?在构建网络时我们该使用第二个,因为前面说到了dropout在训练和测试时是不同的,因此我们的测试时使用model.eval()时,如果只使用的是F.dropout(out, p=0.5)那么不会有变化,但是nn.Dropout(p=0.5)已经被模型注册成了nn.model类,因此此时会改变权重。
如果非要使用F.dropout(out, p=0.5),那可以增加一个参数,根据文档显示:
也就是out = nn.functional.dropout(out, p=0.5, training=self.training)。
接下来是推荐代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | class MyNet(nn.Module): def __init__( self , input_size, num_classes): super (MyNet, self ).__init__() self .fc1 = nn.Linear(input_size, num_classes) # 输入层到输出层 self .dropout = nn.Dropout(p = 0.5 ) # dropout训练 def forward( self , x): out = self .dropout(x) print ( 'dropout层的输出:' , out) out = self .fc1(out) return out input_size = 10 num_classes = 5 model = MyNet(input_size, num_classes) |
大家可以使用以下代码进行测试:
1 2 3 4 5 6 | x = torch.arange( 0 , 10 ).reshape( - 1 ). float () print ( '输入向量' , x) model.train() print ( "训练模式下:" , model(x)) model. eval () print ( "测试模式下:" , model(x)) |
出现问题了,画红框部分应该是训练模式下dropout层的输出,确实有些维度变成了0,但是怎么会出现8,10,仔细一看可以发现,dropout在屏蔽一些维度的数值同时,会将没有屏蔽的数值进行调整(缩放)乘以,此处的p=0.5,因此代入公式可以知道,没有被屏蔽的数值应该会被乘以2。这是为什么呢?前文提到了dropout由于会改变输入数据的均值,所以需要对权重进行改变,其实还有一种方式,在屏蔽一些数值的时候同时对其他没有屏蔽的数值进行缩放使其的缩放后的均值依然保持原来均值,这样在测试的时候就不用调整权重了。这就是nn.Dropout的方式(其实F.dropout也是进行了这个样的rescaled)。
参考网页:
理解dropout_张雨石的博客-CSDN博客_dropout
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性