为什么要给网络参数赋初值
既然网络参数通过训练得到,那么其初值是否重要?设置初值不佳是否只影响收敛速度而不影响模型结果?网络参数是否可以设置为全0或者全1?
假设网络的参数W初值都是0,如下图所示,无论输入任何X,第一层的输出A将都为0,再向前传递到y也是0,使用误差函数调参时,每一层的梯度只与该层的输入和输出有关,由于a1,a2值相等,计算出的梯度调整的值,以及调整后的梯度也相等;第二次迭代也同理,由于a1,a2相等,w[2]中各单元的值也相等。因此该层有100个单元与1个单元没有差异,该问题被称为“对称性”问题。
试想将w设置成全1,则有a1=x1+x2,a2=x1+x2,a1与a2值仍然相同,对称性问题依然存在。由此,一般将参数设置为随机值。
设置成随机值还不够,还需要设置成较小的随机值,试想如果w的均值在0.5附近,某一层的输入输出都为500个元素,那么经过该层乘和加的运算,输出约是输入值的250倍;如果存在多层,250x250x…,很快就爆炸了。如果在层后使用Sigmoid函数,将值映射到较小的空间,又会发生非线性激活函数的饱和问题,使收敛变慢。
因此,简单的方法是:
W = np.random.randn(o_dim, i_dim) * 0.01
np.zeros((o_dim, 1))
bias不导致对称性,一般设置为0。
常用的初值化方法
全0或全1的初始化方法不能使用,而随机初始化也存在一些问题,由于各层的输入和输出元素个数不同,这使得每一层输出数据的方差也不同,比如层输入500个元素和5个元素,同等大小的w,输出的大小可能差出百倍。不同层的调参将受到影响。
Xavier初始化
假设层的输入有三个元素x1,x2,x3,输入为y,权重分别是w1,w2,w3,
则y值为:
y=w1x1+w2x2+w3x3,在计算参数w的初值时考虑到输入该层元素的个数n:于是出现了Xavier方法。
Kaiming初始化
Xavier的问题是,它没有考虑到激活函数对输出数据分布的影响,它会带偏当前广泛使用的ReLU激活函数的结果,于是He Kaiming提取了针对ReLU激活函数的Kaiming初始化(有时也叫作He初始化)。
其原理是:由于ReLU过滤掉了0以下的输入值,因此,激活函数输出的均值将大于0,为解决这和问题,Kaiming方法修改了生成随机数时的标准差。
ReLU过滤掉了一半数据,因此分子乘2,分母是上一层的输出元素个数,即本层的输入元素个数。推导过程请见论文:《Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification》的Section 2.2。
尽管Kaiming初始化一开始主要针对ReLU激活函数优化,但是目前主流库中的Kaiming函数已经支持sigmoid等多种激活函数,可放心调用。
归一化层
Kaiming的目标也是保证各个层输入和输出数据的方差不变。由于后来归一化层被广泛使用,有效地缓解了均值和方差稳定的问题。因此,在使用归一化层的情况下,使用随机数始初化参数即可。
另外,在有些情况下无法使用归一化层,比如最常用的BN(Batch Normalization)在Batch中数据较少时效果不好,这种情况下就需要选用参数初始化。
在预训练/调优的场景中,一般使用预训练的参数作为模型的初值。
Python参数初始化
在使用Pytorch构建网络时,torch.nn中提供了各种常用初始化方法,直接调用即可。下面列出用于初始化网络的某一层或某几层的常用代码。
def init_network_params(model, method='xavier', keywords, seed=123, debug=False):
for name, w in model.named_parameters():
init = False
for key in keywords:
if key in name:
init = True
if init:
if debug:
print('init layer params', name)
if 'weight' in name:
if method == 'xavier':
nn.init.xavier_normal_(w)
elif method == 'kaiming':
nn.init.kaiming_normal_(w)
else:
nn.init.normal_(w)
elif 'bias' in name:
nn.init.constant_(w, 0)
else:
pass
Pytorch中如果不额外设置,线性层的初值设为:(详见torch.nn.modules.linear):
kaiming_uniform_(self.weight, a=math.sqrt(5))