Implementation of Position Embedding in the PyTorch Transformer
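The positional encoding in the Transformer is a special component: it has no learnable parameters. Its values come from a fixed formula and are added to the word embeddings after the embedding layer. Because it is not a parameter, it is not returned by model.parameters() and is never updated by the optimizer. It is stored as a buffer instead. With the default persistent=True a buffer is still saved in the module's state_dict; only a buffer registered with persistent=False is left out.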
Positional Encoding
In the paper, the positional encodings are added to the input embeddings at the bottoms of the encoder and decoder stacks. In PyTorch, the precomputed table is stored with a special method of nn.Module, register_buffer(). In the positional encoding module, we first call:
self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
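To see what this call gives us, here is a minimal sketch. The class name PosEncDemo and the sizes are hypothetical, and a zero tensor stands in for the sinusoid table. The buffer is reachable as an attribute and is listed by named_buffers(), but it does not appear in parameters(), so an optimizer never sees it.

```python
import torch
from torch import nn


class PosEncDemo(nn.Module):
    """Hypothetical module: a placeholder table registered as a buffer."""

    def __init__(self, n_position: int = 8, d_hid: int = 4):
        super().__init__()
        # Placeholder for the sinusoid table; shape (1, n_position, d_hid).
        self.register_buffer('pos_table', torch.zeros(1, n_position, d_hid))


m = PosEncDemo()
print(m.pos_table.shape)                        # torch.Size([1, 8, 4])
print(len(list(m.parameters())))                # 0: nothing to train
print([name for name, _ in m.named_buffers()])  # ['pos_table']
```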
The source code of register_buffer() is shown below (quoted from an older PyTorch release; it still uses torch._six, which newer versions have removed):
def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
    r"""Adds a buffer to the module.

    This is typically used to register a buffer that should not to be
    considered a model parameter. For example, BatchNorm's ``running_mean``
    is not a parameter, but is part of the module's state. Buffers, by
    default, are persistent and will be saved alongside parameters. This
    behavior can be changed by setting :attr:`persistent` to ``False``. The
    only difference between a persistent buffer and a non-persistent buffer
    is that the latter will not be a part of this module's
    :attr:`state_dict`.

    Buffers can be accessed as attributes using given names.

    Args:
        name (string): name of the buffer. The buffer can be accessed
            from this module using the given name
        tensor (Tensor): buffer to be registered.
        persistent (bool): whether the buffer is part of this module's
            :attr:`state_dict`.

    Example::

        >>> self.register_buffer('running_mean', torch.zeros(num_features))

    """
    if persistent is False and isinstance(self, torch.jit.ScriptModule):
        raise RuntimeError("ScriptModule does not support non-persistent buffers")

    if '_buffers' not in self.__dict__:
        raise AttributeError(
            "cannot assign buffer before Module.__init__() call")
    elif not isinstance(name, torch._six.string_classes):
        raise TypeError("buffer name should be a string. "
                        "Got {}".format(torch.typename(name)))
    elif '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    elif name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    elif hasattr(self, name) and name not in self._buffers:
        raise KeyError("attribute '{}' already exists".format(name))
    elif tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError("cannot assign '{}' object to buffer '{}' "
                        "(torch Tensor or None required)"
                        .format(torch.typename(tensor), name))
    else:
        self._buffers[name] = tensor
        if persistent:
            self._non_persistent_buffers_set.discard(name)
        else:
            self._non_persistent_buffers_set.add(name)
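The validation branches in this source are easy to trigger. The following sketch uses a hypothetical module Demo and a hypothetical helper error_type(); the exception types match the branches quoted above, although the exact messages may differ between PyTorch versions.

```python
import torch
from torch import nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(2))


class TooEarly(nn.Module):
    def __init__(self):
        # Registering before Module.__init__() has created self._buffers.
        self.register_buffer('t', torch.zeros(1))
        super().__init__()


def error_type(fn):
    """Return the name of the exception raised by fn, or None if it succeeds."""
    try:
        fn()
    except Exception as exc:  # only the exception type is reported
        return type(exc).__name__
    return None


m = Demo()
print(error_type(lambda: m.register_buffer('a.b', torch.zeros(1))))     # KeyError
print(error_type(lambda: m.register_buffer('', torch.zeros(1))))        # KeyError
print(error_type(lambda: m.register_buffer('weight', torch.zeros(1))))  # KeyError
print(error_type(lambda: m.register_buffer('x', [1, 2])))               # TypeError
print(error_type(TooEarly))                                             # AttributeError
print(error_type(lambda: m.register_buffer('ok', torch.zeros(1))))      # None
```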
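register_buffer() is a method of nn.Module. It stores the tensor in the module's _buffers dictionary rather than in _parameters, so the tensor is part of the module's state (it moves with .to(device)) but is not returned by parameters() and is never trained. The persistent flag decides whether it is saved: a name listed in the _non_persistent_buffers_set attribute of nn.Module is excluded from the module's :attr:state_dict. The call above uses the default persistent=True, so pos_table is included in state_dict and is saved by torch.save(model.state_dict(), ...). To keep it out of checkpoints, pass persistent=False; the table can always be recomputed from the formula.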
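A small round trip through torch.save and torch.load makes the effect of persistent concrete. The module PE and the helper saved_keys() are hypothetical, and an in-memory io.BytesIO buffer stands in for a checkpoint file.

```python
import io

import torch
from torch import nn


class PE(nn.Module):
    """Hypothetical module; `persistent` is forwarded to register_buffer."""

    def __init__(self, persistent: bool):
        super().__init__()
        self.register_buffer('pos_table', torch.arange(6.0).view(1, 3, 2),
                             persistent=persistent)


def saved_keys(persistent: bool) -> list:
    """Save the state_dict to memory, load it back, and list its keys."""
    buf = io.BytesIO()
    torch.save(PE(persistent).state_dict(), buf)
    buf.seek(0)
    return list(torch.load(buf).keys())


for persistent in (True, False):
    print(persistent, saved_keys(persistent))
# True ['pos_table']
# False []
```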
Calculation
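The paper uses sine and cosine functions of different frequencies:

$$PE_{(pos,\,2i)} = \sin\left(pos / 10000^{2i/d_{\text{model}}}\right)$$
$$PE_{(pos,\,2i+1)} = \cos\left(pos / 10000^{2i/d_{\text{model}}}\right)$$

where pos is the position and i indexes pairs of dimensions, so dimensions 2i and 2i+1 share one frequency. In the code below, d_hid plays the role of d_model.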
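As a quick hand check, the entry for an even dimension 2i is sin(pos / 10000^(2i/d_model)) and the entry for the odd dimension 2i+1 is the cosine of the same angle. The sketch below evaluates one row for pos = 1 with a hypothetical d_model = 4: dimensions 0 and 1 use frequency 1, and dimensions 2 and 3 use frequency 1/100. The helper pe() is illustrative only.

```python
import math


def pe(pos: int, dim: int, d_model: int) -> float:
    """One entry of the sinusoidal table for position `pos`, dimension `dim`."""
    i = dim // 2                                # dims 2i and 2i+1 share a frequency
    angle = pos / 10000 ** (2 * i / d_model)
    return math.sin(angle) if dim % 2 == 0 else math.cos(angle)


row = [round(pe(1, d, 4), 4) for d in range(4)]
print(row)  # [0.8415, 0.5403, 0.01, 1.0]
```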
In PyTorch, the table can be built as follows (this version computes it with NumPy and converts the result to a tensor):
import numpy as np
import torch

def _get_sinusoid_encoding_table(self, n_position, d_hid):
    ''' Sinusoid position encoding table '''
    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        # angle pos / 10000^(2i / d_hid) for every dimension j, where i = j // 2
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    # [:, 0::2] selects the even dimensions 2i, which use sine
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    # [:, 1::2] selects the odd dimensions 2i+1, which use cosine
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    # shape (1, n_position, d_hid): the leading 1 broadcasts over the batch
    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
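The TODO in this function asks for a torch-only version. One possible sketch (the function name sinusoid_table_torch is hypothetical) computes the same angles with broadcasting instead of a Python list comprehension, in float64 like the NumPy version, and returns a float32 tensor of the same shape.

```python
import torch


def sinusoid_table_torch(n_position: int, d_hid: int) -> torch.Tensor:
    """Sinusoid table of shape (1, n_position, d_hid), built with torch only."""
    pos = torch.arange(n_position, dtype=torch.float64).unsqueeze(1)  # (n_position, 1)
    hid = torch.arange(d_hid)                                         # (d_hid,)
    # Exponent 2 * (j // 2) / d_hid, as in get_position_angle_vec.
    i = torch.div(hid, 2, rounding_mode='floor').to(torch.float64)
    angles = pos / torch.pow(10000.0, 2 * i / d_hid)                  # (n_position, d_hid)
    table = torch.empty_like(angles)
    table[:, 0::2] = torch.sin(angles[:, 0::2])  # even dims 2i -> sine
    table[:, 1::2] = torch.cos(angles[:, 1::2])  # odd dims 2i+1 -> cosine
    return table.float().unsqueeze(0)
```

Assuming the rest of the module stays unchanged, self.register_buffer('pos_table', sinusoid_table_torch(n_position, d_hid)) would be a drop-in replacement for the NumPy-based call.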
The last part is the forward function:
def forward(self, x):
    # x: (batch, seq_len, d_hid); pos_table: (1, n_position, d_hid) broadcasts over the batch
    return x + self.pos_table[:, :x.size(1)].clone().detach()
This adds the positional encoding to the output of the word embedding. Only the first seq_len rows of the table are used, so inputs may be any length up to n_position.
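Putting the pieces together, here is a minimal sketch of the whole module placed after an nn.Embedding layer. The class name PositionalEncoding, the default n_position=200, and the sizes in the usage lines are assumptions, not taken from the text. The table is built inline with torch, and because the buffer does not require gradients, the sketch adds it without clone().detach().

```python
import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoid table (a buffer, not a parameter) to its input."""

    def __init__(self, d_hid: int, n_position: int = 200):
        super().__init__()
        pos = torch.arange(n_position, dtype=torch.float64).unsqueeze(1)
        i = torch.div(torch.arange(d_hid), 2, rounding_mode='floor').to(torch.float64)
        angles = pos / torch.pow(10000.0, 2 * i / d_hid)
        angles[:, 0::2] = torch.sin(angles[:, 0::2])  # even dims -> sine
        angles[:, 1::2] = torch.cos(angles[:, 1::2])  # odd dims -> cosine
        self.register_buffer('pos_table', angles.float().unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_hid); the table's leading 1 broadcasts over batch.
        return x + self.pos_table[:, :x.size(1)]


vocab_size, d_model = 100, 16                    # hypothetical sizes
embed = nn.Embedding(vocab_size, d_model)
pos_enc = PositionalEncoding(d_model)
tokens = torch.randint(0, vocab_size, (2, 7))    # (batch=2, seq_len=7)
out = pos_enc(embed(tokens))
print(out.shape)  # torch.Size([2, 7, 16])
```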