Code Notes 10: Reducing a Model's GPU Memory Footprint During Training in PyTorch

1

  I had never used the model-summary tool torchsummary (a third-party package, not something bundled with PyTorch), so I installed it to try it out.
It does not seem to be installable with conda, only with pip.
Activate the corresponding conda virtual environment and run

pip install torchsummary

Remember to turn off any VPN or proxy first, otherwise pip will report an error.
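
A minimal usage sketch (the tiny model and the (3, 224, 224) input size are just placeholders; device="cpu" lets the summary run without a GPU):

import torch.nn as nn
from torchsummary import summary

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
)
# prints a per-layer table plus the total parameter count and size estimates
summary(model, (3, 224, 224), device="cpu")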

2

  Here is what prompted this note. After installing it, the summary showed that my model's memory footprint and estimated size were much larger than those of the equivalent model on GitHub — not slightly larger, but by a wide margin. No wonder I kept running out of GPU memory: my own code was the problem. So I looked into which coding patterns affect the numbers.

The first factor is a Module class that wraps other Module classes.
When I started learning PyTorch, a tutorial suggested writing a two- or three-layer convolution block as its own Module class and then instantiating it inside the network. That is convenient to write, but torchsummary reports a noticeably larger memory estimate for it.
Here is an example:

import torch.nn as nn
import torch
from torchsummary import summary


class doubleconv1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
    def forward(self,x):
        x = x.to("cpu")  # keep the input on the CPU, where the model's weights live, regardless of the device summary() uses
        c1 = self.conv1(x)
        c2 = self.conv2(c1)

        return c2

class doubleconv2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
    def forward(self,x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)

        return c2

class net1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)


    def forward(self,x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)

        return c4

class net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = doubleconv1()
        self.conv2 = doubleconv2()


    def forward(self, x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)


        return c2


if __name__ == "__main__":
    net11 = net1()
    net22 = net2()
    summary(net11, (3, 512, 704))
    summary(net22, (3, 512, 704))

Running the script produces the following results.
For net11, summary reports

Total params: 112,576
Trainable params: 112,576
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.12
Forward/backward pass size (MB): 704.00
Params size (MB): 0.43
Estimated Total Size (MB): 708.55
----------------------------------------------------------------

and for net22

Total params: 112,576
Trainable params: 112,576
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.12
Forward/backward pass size (MB): 1056.00
Params size (MB): 0.43
Estimated Total Size (MB): 1060.55
----------------------------------------------------------------

Same parameter count, yet the estimated total is roughly 350 MB larger (the forward/backward pass size goes from 704 MB to 1056 MB). My current understanding — I still need to verify this — is that torchsummary registers a forward hook on every sub-module, so the output of each doubleconv wrapper is counted once for the wrapper itself and again for its last inner convolution; two extra 64×512×704 feature maps, counted for both the forward and the backward pass, account for exactly the 352 MB gap. So much of the difference seems to be how the tool counts, but the wrapper style clearly makes the report harder to trust.
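
If the block abstraction is still wanted, one workaround is to build it with nn.Sequential instead of a custom Module class. This is a sketch based on my reading of torchsummary, which skips nn.Sequential containers when registering hooks, so only the inner Conv2d outputs should be counted; the names double_conv and net5 are just illustrative:

import torch.nn as nn
from torchsummary import summary


def double_conv(in_ch, out_ch):
    # nn.Sequential is treated as a container, so torchsummary should only
    # count the outputs of the two Conv2d layers inside it
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
    )


class net5(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = double_conv(3, 64)
        self.block2 = double_conv(64, 64)

    def forward(self, x):
        return self.block2(self.block1(x))


if __name__ == "__main__":
    summary(net5(), (3, 512, 704), device="cpu")

With this layout the reported forward/backward pass size should match net11's 704 MB rather than net22's 1056 MB.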

3

The second factor is pooling and activation layers. They can be applied through torch.nn.functional inside forward() instead of being declared as layers in the network definition; declaring them as layers also adds to the reported memory.

import torch.nn as nn
import torch
from torchsummary import summary
import torch.nn.functional as F


class doubleconv1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
    def forward(self,x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)

        return c2

class doubleconv2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
    def forward(self,x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)

        return c2

class net1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)



    def forward(self,x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)

        return c4

class net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = doubleconv1()
        self.conv2 = doubleconv2()


    def forward(self, x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        c2 = self.conv2(c1)


        return c2

class net3(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2,stride=2)
        self.ac1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.ac2 = nn.ReLU()
        self.conv3 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.ac3 = nn.ReLU()
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.ac4 = nn.ReLU()


    def forward(self,x):
        x = x.to("cpu")
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        a1 = self.ac1(p1)
        c2 = self.conv2(a1)
        p2 = self.pool2(c2)
        a2 = self.ac2(p2)
        c3 = self.conv3(a2)
        p3 = self.pool3(c3)
        a3 = self.ac3(p3)
        c4 = self.conv4(a3)
        p4 = self.pool4(c4)
        a4 = self.ac4(p4)

        return a4

class net4(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,stride=1,padding=1)

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)

        self.conv3 = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1,padding=1)

        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)



    def forward(self,x):
        x = x.to("cpu")
        c1 =  F.relu(F.max_pool2d(self.conv1(x), kernel_size=2, stride=2))

        c2 = F.relu(F.max_pool2d(self.conv2(c1), kernel_size=2, stride=2))

        c3 = F.relu(F.max_pool2d(self.conv3(c2), kernel_size=2, stride=2))

        c4 = F.relu(F.max_pool2d(self.conv4(c3), kernel_size=2, stride=2))


        return c4

if __name__ == "__main__":
    net11 = net1()
    net22 = net2()
    net33 = net3()
    net44 = net4()
    summary(net33, (3, 512, 704))
    summary(net44, (3, 512, 704))

The output for net33 is

Total params: 112,576
Trainable params: 112,576
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.12
Forward/backward pass size (MB): 350.62
Params size (MB): 0.43
Estimated Total Size (MB): 355.18
----------------------------------------------------------------

and for net44

Total params: 112,576
Trainable params: 112,576
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.12
Forward/backward pass size (MB): 233.75
Params size (MB): 0.43
Estimated Total Size (MB): 238.30
----------------------------------------------------------------

This likewise reduces the reported size: the forward/backward estimate drops from about 351 MB to about 234 MB. As far as I can tell, that is because torchsummary only hooks nn.Module layers, so pooling and activation outputs are simply not counted when they are called through torch.nn.functional, even though those tensors are still created at run time.
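
Since torchsummary's numbers are only an estimate (it sums the outputs of hooked layers), it is worth checking what the GPU actually allocates when memory is the real concern. A minimal sketch using torch.cuda's built-in counters — the small stand-in model is illustrative, and the same check applies to the networks above once the x.to("cpu") lines are removed from their forward methods:

import torch
import torch.nn as nn

# small stand-in model, just to have something to measure
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
).cuda()

x = torch.randn(1, 3, 512, 704, device="cuda")

torch.cuda.reset_peak_memory_stats()
out = model(x)
out.sum().backward()  # dummy backward pass so gradient buffers are included in the measurement

print("currently allocated (MB):", torch.cuda.memory_allocated() / 1024 ** 2)
print("peak allocated (MB):", torch.cuda.max_memory_allocated() / 1024 ** 2)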

posted @ 2022-05-25 20:51  The1912