model.apply(weights_init_normal)

model.apply(weights_init_normal)方法

应用把方法应用于每一个module,这里意思是进行初始化

  1. def weights_init_normal(m): 
  2. classname = m.__class__.__name__ 
  3. if classname.find("Conv") != -1: 
  4. torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 
  5. elif classname.find("BatchNorm2d") != -1: 
  6. torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 
  7. torch.nn.init.constant_(m.bias.data, 0.0) 

这里的意思是选择module是conv或者是batchNorm2d的层进行初始化

posted @ 2020-03-10 10:59  vivia~  阅读(4046)  评论(0编辑  收藏  举报