【PyTorch】Module

Model

children

 

parameters

modules

state_dict

Container

 1 import torch
 2 import torch.nn as nn
 3 from torchvision.models.resnet import (_resnet, Bottleneck)
 4 
 5 
 6 class model_with_container(nn.Module):
 7     def __init__(self, type):
 8         super(model_with_container, self).__init__()
 9         self.type = type
10         if self.type == 'ModuleList':
11             self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(5)])
12         elif self.type == 'Sequential':
13             self.seq = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
14 
15     def forward(self, x):
16         if self.type == 'ModuleList':
17             for idx in [2, 3, 4, 0, 1]:
18                 x = self.linears[idx](x)
19         elif self.type == 'Sequential':
20             x = self.seq(x)
21         return self.type, x
22 
23 
def creat_model():
    """Build a pruned, randomly initialised Bottleneck ResNet for the demos.

    The network is created with layer counts ``[2, 0, 0, 0]``. Then
    ``conv2``/``bn2``/``conv3``/``bn3`` are removed from both blocks of
    ``layer1``, and ``layer2``, ``layer3`` and ``layer4`` are removed from
    the model itself, which keeps the printed listings short.

    NOTE(review): the pruned model is presumably meant only for inspecting
    its children/parameters/modules/state_dict; its forward pass likely no
    longer works once these sub-layers are gone -- confirm before calling it.

    Returns:
        The pruned ``ResNet`` instance.
    """
    # Local import so this block stands alone. Instantiating ResNet directly
    # replaces the private ``_resnet`` helper, whose positional signature
    # changed between torchvision releases (the old five-argument call
    # raises TypeError on newer versions). Without pretrained weights the
    # helper did nothing beyond this constructor call.
    from torchvision.models.resnet import ResNet

    model = ResNet(Bottleneck, [2, 0, 0, 0])

    for block in (model.layer1[0], model.layer1[1]):
        for layer_name in ('conv2', 'bn2', 'conv3', 'bn3'):
            delattr(block, layer_name)

    for stage_name in ('layer2', 'layer3', 'layer4'):
        delattr(model, stage_name)

    return model
38 
39 
def test_model():
    """Show that each ``named_*`` iterator pairs up with its unnamed twin.

    For children, parameters and modules (in that order), prints every name
    together with whether the object yielded by the ``named_*`` iterator is
    the very same object yielded by the plain iterator. Finally prints every
    key of the model's ``state_dict``.
    """
    model = creat_model()

    # The iterators are lazy generators, so building all three pairs up
    # front does not change the order in which anything is printed.
    iterator_pairs = (
        (model.named_children(), model.children()),
        (model.named_parameters(), model.parameters()),
        (model.named_modules(), model.modules()),
    )
    for named_iter, plain_iter in iterator_pairs:
        for (name, named_obj), plain_obj in zip(named_iter, plain_iter):
            # Identity, not equality: both iterators should hand back the
            # same underlying object.
            print(name, named_obj is plain_obj)

    # Only the keys are printed, so iterate the mapping directly.
    for name in model.state_dict():
        print(name)
59 
60 
def test_container():
    """Print a container model and the result of one forward pass.

    Feeds a random ``2 x 10`` batch through a ``'ModuleList'`` flavoured
    ``model_with_container`` and prints the model followed by its output,
    one per line.
    """
    # The input is drawn before the model is built, so the order of random
    # number consumption matches the original listing.
    batch = torch.randn([2, 10])

    # Pass 'Sequential' instead to exercise the other container.
    net = model_with_container('ModuleList')

    result = net(batch)
    print(net, result, sep='\n')
66 
67 
# Script entry point: run the two demos in order when executed directly.
if __name__ == '__main__':
    for demo in (test_model, test_container):
        demo()
posted @ 2022-03-10 18:51  Vivid-BinGo  阅读(48)  评论(0)  编辑  收藏  举报