Code Notes 10: How to Reduce a Model's GPU Memory Usage During Training in PyTorch
1
Since I had never used torchsummary (a third-party summary/visualization tool for PyTorch models, not something bundled with torch itself), I installed it to try it out.
It doesn't seem to be installable through conda; it has to be installed with pip.
Activate the corresponding conda virtual environment and run
pip install torchsummary
Note: remember to turn off any VPN or proxy first, otherwise the install will throw an error.
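After installing, the basic call is summary(model, input_size). Here is a minimal sketch of how it is used (my own example, not from the original notes; it assumes a torchsummary version whose summary() accepts a device argument, which defaults to "cuda", so "cpu" is passed here to run on a machine without a GPU):

import torch.nn as nn
from torchsummary import summary

# a toy two-layer model, just to exercise the tool
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
)
# prints a per-layer table plus the parameter / memory estimates discussed below
summary(model, (3, 224, 224), device="cpu")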
2
Here is what prompted this. After installing torchsummary, the summaries showed that the memory footprint and estimated model size of the model I had written were much larger than those of the version on GitHub, and not just slightly larger, but by a lot. No wonder I kept running out of GPU memory: the way I wrote the model was the problem. So I dug into which coding patterns actually make a difference.
The first one is a Module class referencing other Module classes.
When I first started learning PyTorch, a tutorial suggested writing a two- or three-layer convolution block as its own Module class and then reusing it inside the network architecture. That is indeed convenient to write, but it really does eat up GPU memory...
Here is an example:
import torch.nn as nn
import torch
from torchsummary import summary


class doubleconv1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = x.to("cpu")  # keep the input on the CPU, matching the model's (CPU) parameters
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        return c2


class doubleconv2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        return c2


class net1(nn.Module):
    # four convolutions declared directly in a single module
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)
        return c4


class net2(nn.Module):
    # the same four convolutions, but wrapped in two nested Module classes
    def __init__(self):
        super().__init__()
        self.conv1 = doubleconv1()
        self.conv2 = doubleconv2()

    def forward(self, x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        return c2


if __name__ == "__main__":
    net11 = net1()
    net22 = net2()
    summary(net11, (3, 512, 704))
    summary(net22, (3, 512, 704))
Running this gives the following results.
The summary for net11 is:
Total params: 112,576
Trainable params: 112,576
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.12
Forward/backward pass size (MB): 704.00
Params size (MB): 0.43
Estimated Total Size (MB): 708.55
----------------------------------------------------------------
The summary for net22 is:
Total params: 112,576
Trainable params: 112,576
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.12
Forward/backward pass size (MB): 1056.00
Params size (MB): 0.43
Estimated Total Size (MB): 1060.55
----------------------------------------------------------------
With exactly the same parameter count, the estimated total size is a good 300 MB larger. As for the exact reason, I still need to study it further and don't fully understand it yet. (My current guess is that torchsummary registers a hook on every submodule, so the output of each doubleconv wrapper gets counted on top of the outputs of the convolutions inside it, which inflates the reported forward/backward pass size.)
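To check whether that gap is real allocated memory or just how torchsummary counts nested modules, a rough cross-check on an actual GPU could look like the sketch below. This is my own addition, not part of the original comparison; it assumes a CUDA device is available and that the x.to("cpu") lines in the forward methods are removed so the models can run on the GPU.

import torch

def peak_forward_backward_mb(model, input_shape=(1, 3, 512, 704)):
    # hypothetical helper: run one forward/backward pass on the GPU and report
    # the peak memory held by the caching allocator, in MB
    model = model.cuda()
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(*input_shape, device="cuda")
    model(x).sum().backward()
    return torch.cuda.max_memory_allocated() / 1024 ** 2

# e.g. compare peak_forward_backward_mb(net1()) against peak_forward_backward_mb(net2())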
3
The second cause is the pooling and activation layers. They can be implemented through torch.nn.functional instead of being declared as layers in the network architecture; declaring them as layers also adds to the memory usage.
import torch.nn as nn
import torch
from torchsummary import summary
import torch.nn.functional as F


class doubleconv1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        return c2


class doubleconv2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        return c2


class net1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)
        return c4


class net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = doubleconv1()
        self.conv2 = doubleconv2()

    def forward(self, x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        return c2


class net3(nn.Module):
    # pooling and activation declared as layers in __init__
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.ac1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.ac2 = nn.ReLU()
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.ac3 = nn.ReLU()
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.ac4 = nn.ReLU()

    def forward(self, x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        a1 = self.ac1(p1)
        c2 = self.conv2(a1)
        p2 = self.pool2(c2)
        a2 = self.ac2(p2)
        c3 = self.conv3(a2)
        p3 = self.pool3(c3)
        a3 = self.ac3(p3)
        c4 = self.conv4(a3)
        p4 = self.pool4(c4)
        a4 = self.ac4(p4)
        return a4


class net4(nn.Module):
    # same structure, but pooling and activation done via torch.nn.functional
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = x.to("cpu")
        c1 = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2, stride=2))
        c2 = F.relu(F.max_pool2d(self.conv2(c1), kernel_size=2, stride=2))
        c3 = F.relu(F.max_pool2d(self.conv3(c2), kernel_size=2, stride=2))
        c4 = F.relu(F.max_pool2d(self.conv4(c3), kernel_size=2, stride=2))
        return c4
if __name__ == "__main__":
    net11 = net1()
    net22 = net2()
    net33 = net3()
    net44 = net4()
    summary(net33, (3, 512, 704))
    summary(net44, (3, 512, 704))
The output for net33 is:
Total params: 112,576
Trainable params: 112,576
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.12
Forward/backward pass size (MB): 350.62
Params size (MB): 0.43
Estimated Total Size (MB): 355.18
----------------------------------------------------------------
The output for net44 is:
Total params: 112,576
Trainable params: 112,576
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.12
Forward/backward pass size (MB): 233.75
Params size (MB): 0.43
Estimated Total Size (MB): 238.30
----------------------------------------------------------------
So writing the pooling and activation as functional calls likewise reduces the occupied space.
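For reference, the two forward/backward numbers above can be reproduced with a back-of-envelope calculation. This is my own sketch, assuming torchsummary sums the output size of every hooked layer (float32, 4 bytes per element) and doubles it for the backward pass: in net33 every declared conv/pool/ReLU layer is counted, while in net44 only the four convolutions are modules, so only their outputs count.

def mb(c, h, w):
    # size of one float32 feature map of shape (c, h, w) in MB
    return c * h * w * 4 / 1024 ** 2

# net33: every declared layer's output is counted
net33_outputs = (
    [mb(64, 512, 704)]        # conv1
    + [mb(64, 256, 352)] * 3  # pool1, ac1, conv2
    + [mb(64, 128, 176)] * 3  # pool2, ac2, conv3
    + [mb(64, 64, 88)] * 3    # pool3, ac3, conv4
    + [mb(64, 32, 44)] * 2    # pool4, ac4
)

# net44: only the conv outputs are counted
net44_outputs = [mb(64, 512, 704), mb(64, 256, 352), mb(64, 128, 176), mb(64, 64, 88)]

print(sum(net33_outputs) * 2)  # 350.625 -> matches the 350.62 reported for net33
print(sum(net44_outputs) * 2)  # 233.75  -> matches the 233.75 reported for net44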