神经网络中的权重初始化方式和pytorch应用

深度学习模型中的权重初始化对模型的训练效果有很大的影响,对预训练模型的研究就是为了在大模型上先训练出较好的权重,然后再放到不同的小任务上微调。

对于不加载预训练的模型,仍然可以通过定义模型权重初始化的方式来使得模型获得较好的效果,以下介绍不同的权重初始化方法、适用场景及效果。

计算增益

对于线性

nonlinearity gain
Linear / Identity \(1\)
Conv{1,2,3}D \(1\)
Sigmoid \(1\)
Tanh \(\frac{5}{3}\)
ReLU \(\sqrt{2}\)
Leaky Relu \(\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}\)
SELU \(\frac{3}{4}\)

常数初始化

torch.nn.init.constant_(tensor, val)
按照常数val初始化tensor。

特别的,val为0和1分别有torch.nn.init.zeros_(tensor)torch.nn.init.ones_(tensor)

均匀分布初始化

torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
按照\(U(a,b)\)的均匀分布初始化tensor。

正态分布初始化

torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
按照\(N(mean,std^2)\)的均匀分布初始化tensor。

Xavier初始化

均匀分布(glorot初始化)

torch.nn.init.xavier_uniform_(tensor, gain=1.0)
按照\(U(-a,a)\)的均匀分布初始化tensor,其中

\[a = gain \times \sqrt{\frac{6}{fan\_in + fan\_out}} \]

正态分布

torch.nn.init.xavier_normal_(tensor, gain=1.0)

按照\(N(0,std^2)\)的均匀分布初始化tensor,其中

\[std = gain \times \sqrt{\frac{2}{fan\_in + fan\_out}} \]

Kaiming初始化

均匀分布

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

按照\(U(-a,a)\)的均匀分布初始化tensor,其中

\[a = gain \times \sqrt{\frac{3}{fan\_mode}} \]

正态分布

torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

按照\(N(0,std^2)\)的均匀分布初始化tensor,其中

\[std = \sqrt{\frac{gain}{fan\_mode}} \]

具体应用

pytorch中的torch.nn.init模块中有多种初始化的方法,可以显式地定义,以下是一个例子:

 def init_weights(self):
     for m in self.modules():
         if isinstance(m, GCNConv):
             m.weight.data = init.xavier_uniform(
                 m.weight.data, gain=torch.nn.init.calculate_gain("relu")
             )
             if m.bias is not None:
                 m.bias.data = init.constant(m.bias.data, 0.0)

这个函数是模型类的成员函数,它表示的是检索这个类中的所有模块,如果有GCNConv类的话,就将对应的weights用xavier均匀分布的方法初始化,如果有bias的话用常数来初始化bias,在对象初始化的时候调用self.init_weights();就可以了。

一些问答或tips

1. How to use torch.nn.init.calculate_gain?
2. How to Initialize Weights in PyTorch
3. Weight Initialization Techniques in Neural Networks
4. 网络权重初始化方法总结(上):梯度消失、梯度爆炸与不良的初始化
5. 网络权重初始化方法总结(下):Lecun、Xavier与He Kaiming
6.pytorch-nn.init模块文档

posted @ 2022-11-18 15:32  阿莱慢慢来  阅读(346)  评论(0)    收藏  举报