Pytorch的nn.Module类中定义的实例方法apply()
参考文档Module — PyTorch 1.7.0 documentation
1 @torch.no_grad() 2 def init_weights(m): 3 print(m) 4 if type(m) == nn.Linear: 5 m.weight.fill_(1.0) 6 print(m.weight) 7 net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) 8 net.apply(init_weights)
net类及其子类都会调用 init_weights() 方法