model.apply(weights_init_normal)
model.apply(weights_init_normal)方法
应用把方法应用于每一个module,这里意思是进行初始化
- def weights_init_normal(m):
- classname = m.__class__.__name__
- if classname.find("Conv") != -1:
- torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
- elif classname.find("BatchNorm2d") != -1:
- torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
- torch.nn.init.constant_(m.bias.data, 0.0)
这里的意思是选择module是conv或者是batchNorm2d的层进行初始化