"""【PyTorch】Module

Model
    children
    parameters
    modules
    state_dict
Container
"""
1 import torch
2 import torch.nn as nn
3 from torchvision.models.resnet import (_resnet, Bottleneck)
4
5
class model_with_container(nn.Module):
    """Toy model contrasting ``nn.ModuleList`` with ``nn.Sequential``.

    Args:
        type: ``'ModuleList'`` builds five ``Linear(10, 10)`` layers that
            ``forward`` applies in a custom (non-registration) order;
            ``'Sequential'`` builds two ``Linear(10, 10)`` layers applied in
            the order they were passed. The name shadows the builtin but is
            kept so existing keyword callers keep working.

    Raises:
        ValueError: If ``type`` is not one of the supported container names.
    """

    _SUPPORTED_TYPES = ('ModuleList', 'Sequential')

    def __init__(self, type):
        super().__init__()
        # Fail fast: an unknown value used to be accepted silently, and
        # forward() then returned its input untouched.
        if type not in self._SUPPORTED_TYPES:
            raise ValueError(
                f"type must be one of {self._SUPPORTED_TYPES}, got {type!r}"
            )
        self.type = type
        if type == 'ModuleList':
            # The ModuleList is only indexed (never called) in forward(), so
            # forward() alone decides the order in which the layers run.
            self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])
        elif type == 'Sequential':
            # Calling the Sequential runs its children in construction order.
            self.seq = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))

    def forward(self, x):
        """Run ``x`` through the configured container.

        Returns:
            A tuple ``(self.type, output)``.
        """
        if self.type == 'ModuleList':
            # Deliberately out of registration order to show that the
            # ModuleList itself imposes no execution order.
            for idx in [2, 3, 4, 0, 1]:
                x = self.linears[idx](x)
        elif self.type == 'Sequential':
            x = self.seq(x)
        return self.type, x
22
23
def creat_model():
    """Build a pruned ResNet-style model for demonstrating module introspection.

    Builds a ``ResNet`` with ``Bottleneck`` blocks and layer counts
    ``[2, 0, 0, 0]``. It then removes ``conv2``/``bn2``/``conv3``/``bn3`` from
    both blocks of ``layer1`` and removes ``layer2``, ``layer3`` and ``layer4``
    entirely, which keeps the printed module tree short.

    Returns:
        The pruned model. NOTE(review): the pruned blocks presumably can no
        longer run ``forward()``; in this file the model is only inspected.
    """
    # Use the public ``ResNet`` class rather than the private ``_resnet``
    # helper. ``_resnet`` dropped its ``arch`` argument and renamed
    # ``pretrained`` to ``weights`` in torchvision 0.13, so the old
    # five-positional-argument call raises TypeError there. Without
    # pretrained weights, the old ``_resnet`` simply returned ``ResNet(...)``.
    from torchvision.models.resnet import Bottleneck, ResNet

    model = ResNet(Bottleneck, [2, 0, 0, 0])
    for block in model.layer1:  # layer1 holds exactly two blocks here
        for attr in ('conv2', 'bn2', 'conv3', 'bn3'):
            delattr(block, attr)
    for attr in ('layer2', 'layer3', 'layer4'):
        delattr(model, attr)
    return model
38
39
def test_model():
    """Compare the named and unnamed introspection iterators of the demo model.

    For ``children``, ``parameters`` and ``modules``, print one line per item:
    the item's qualified name, then whether the ``named_*`` iterator and its
    plain counterpart yielded the very same object at that position. Finally
    print every key of the model's ``state_dict()``.
    """
    net = creat_model()

    def _report(named_iter, plain_iter):
        # Walk both iterators in lockstep; identity (not equality) is checked.
        for (label, named_obj), plain_obj in zip(named_iter, plain_iter):
            print(label, named_obj is plain_obj)

    _report(net.named_children(), net.children())
    _report(net.named_parameters(), net.parameters())
    _report(net.named_modules(), net.modules())

    # Only the keys are printed, so iterate the mapping directly.
    for key in net.state_dict():
        print(key)
59
60
def test_container(container_type='ModuleList'):
    """Print a ``model_with_container`` and its output on a random batch.

    Args:
        container_type: Container kind passed to ``model_with_container``,
            either ``'ModuleList'`` (the default, matching the previous
            behavior) or ``'Sequential'``.
    """
    x = torch.randn([2, 10])  # batch of 2 samples with 10 features each
    model = model_with_container(container_type)
    print(model, model(x), sep='\n')
66
67
if __name__ == '__main__':
    # Script entry point: run the introspection demo, then the container demo.
    for demo in (test_model, test_container):
        demo()