pytorch---初始化
在深度学习中参数的初始化十分重要,良好的初始化能让模型更快收敛,并达到更高水平,而糟糕的初始化则可能使得模型迅速瘫痪。PyTorch中nn.Module的模块参数都采取了较为合理的初始化策略,因此一般不用我们考虑,当然我们也可以用自定义初始化去代替系统的默认初始化。而当我们在使用Parameter时,自定义初始化则尤为重要,因t.Tensor()返回的是内存中的随机数,很可能会有极大值,这在实际训练网络中会造成溢出或者梯度消失。PyTorch中nn.init
模块就是专门为初始化而设计,如果某种初始化策略nn.init
不提供,用户也可以自己直接初始化。
# 利用nn.init初始化
from torch.nn import init
linear = nn.Linear(3, 4)
t.manual_seed(1)
# 等价于 linear.weight.data.normal_(0, std)
init.xavier_normal_(linear.weight)
# 直接初始化
import math
t.manual_seed(1)
# xavier初始化的计算公式
std = math.sqrt(2)/math.sqrt(7.)
linear.weight.data.normal_(0,std)
# 对模型的所有参数进行初始化
for name, params in net.named_parameters():
if name.find('linear') != -1:
# init linear
params[0] # weight
params[1] # bias
elif name.find('conv') != -1:
pass
elif name.find('norm') != -1:
pass
补充
xavier初始化
torch.nn.init.xavier_uniform(tensor, gain=1)
对于输入的tensor或者变量,通过论文Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的方法初始化数据。
初始化服从均匀分布U(−a,a)U(−a,a),其中a=gain×2/(fan_in+fan_out)−−−−−−−−−−−−−−−−−−√×3–√a=gain×2/(fan_in+fan_out)×3,该初始化方法也称Glorot initialisation。
参数:
tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据
a:可选择的缩放参数
例如:
w = torch.Tensor(3, 5)
nn.init.xavier_uniform(w, gain=nn.init.calculate_gain('relu'))
torch.nn.init.xavier_normal(tensor, gain=1)
对于输入的tensor或者变量,通过论文Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的方法初始化数据。初始化服从高斯分布N(0,std)N(0,std),其中std=gain×2/(fan_in+fan_out)−−−−−−−−−−−−−−−−−−√std=gain×2/(fan_in+fan_out),该初始化方法也称Glorot initialisation。
参数:
tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据
a:可选择的缩放参数
例如:
w = torch.Tensor(3, 5)
nn.init.xavier_normal(w)
另外在torch.Tensor下还定义了一些in-place的函数:
-
torch.Tensor.bernoulli_()
- in-place version oftorch.bernoulli()
,伯努利分布 -
torch.Tensor.cauchy_()
- numbers drawn from the Cauchy distribution,柯西分布 -
torch.Tensor.exponential_()
- numbers drawn from the exponential distribution,指数分布 -
torch.Tensor.geometric_()
- elements drawn from the geometric distribution,几何分布 -
torch.Tensor.log_normal_()
- samples from the log-normal distribution,对数正太分布 -
torch.Tensor.normal_()
- in-place version oftorch.normal()
,正太分布 -
torch.Tensor.random_()
- numbers sampled from the discrete uniform distribution,均匀分布 -
torch.Tensor.uniform_()
- numbers sampled from the continuous uniform distribution,连续均匀分布每个的参数不同,像均匀分布等有均值和方差。