PyTorch:nn.Embedding.weight和nn.Embedding的区别
太长不看版:
如果非直接使用nn.Embedding而使用nn.Embedding.weight来作为变量,其随机初始化方式是自带标准正态分布,即均值
,方差
的正态分布。
下面是论据
源代码:
import torch
from torch.nn.parameter import Parameter
from .module import Module
from .. import functional as F
class Embedding(Module):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2, scale_grad_by_freq=False,
sparse=False, _weight=None):
if _weight is None:
self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.reset_parameters()
else:
assert list(_weight.shape) == [num_embeddings, embedding_dim], \
'Shape of weight does not match num_embeddings and embedding_dim'
self.weight = Parameter(_weight)
def reset_parameters(self):
self.weight.data.normal_(0, 1)
if self.padding_idx is not None:
self.weight.data[self.padding_idx].fill_(0)
Embedding
这个类有个属性weight
,它是torch.nn.parameter.Parameter
类型的,作用就是存储真正的word embeddings。如果不给weight
赋值,Embedding
类会自动给他初始化,看上述代码第6~8行,如果属性weight
没有手动赋值,则会定义一个torch.nn.parameter.Parameter
对象,然后对该对象进行reset_parameters()
,看第21行,对self.weight
先转为Tensor
在对其进行normal_(0, 1)
(调整为$N(0, 1)$正态分布)。所以nn.Embeddig.weight
默认初始化方式就是N(0, 1)分布,即均值$\mu=0$,方差$\sigma=1$的标准正态分布。
下面将做的是验证nn.Embeddig.weight
某一行词向量的均值和方差,以便验证是否为标准正态分布。
注意:验证一行数字的均值为0,方差为1,显然不能说明该分布就是标准正态分布,只能是其必要条件,而不是充分条件,要想真正检测这行数字是不是正态分布,在概率论上有专门的较为复杂的方法,请查看概率论之假设检验。
import torch.nn as nn
# dim越大,均值、方差越接近0和1
dim = 800000
# 定义了一个(5, dim)的二维embdding
# 对于NLP来说,相当于是5个词,每个词的词向量维数是dim
# 每个词向量初始化为正态分布 N(0,1)(待验证)
embd = nn.Embedding(5, dim)
# type(embd.weight) is Parameter
# type(embd.weight.data) is Tensor
# embd.weight.data[0]是指(5, dim)的word embeddings中取第1个词的词向量,是dim维行向量
weight = embd.weight.data[0].numpy()
print("weight: {}".format(weight))
weight_sum = 0
for w in weight:
weight_sum += w
mean = weight_sum / dim
print("均值: {}".format(mean))
square_sum = 0
for w in weight:
square_sum += (mean - w) ** 2
print("方差: {}".format(square_sum / dim))
代码输出:
weight: [-0.65507996 0.11627434 -1.6705967 ... 0.78397447 ... -0.13477565]
均值: 0.0006973597864689242
方差: 1.0019535550544454
可见,均值接近0,方差接近1,从这里也可以反映出nn.Embeddig.weight
是标准正态分布$N(0, 1)$。