Implementation of Position Embedding in the PyTorch Transformer


The positional encoding in the Transformer is a special component because it is not a learnable part of the network. It is computed from a fixed formula and added to the word embeddings right after the embedding layer, so it has no trainable parameters and the optimizer never updates it. Its values are stored in a buffer rather than a parameter. Note that, by default, a buffer is still included in the module's state_dict, so it is saved along with the model parameters.

Positional Encoding

In the paper, the positional encodings are added to the input embeddings at the bottoms of the encoder and decoder stacks. In PyTorch, the precomputed table is stored with a special nn.Module method, register_buffer. In the positional encoding module's __init__, we first call:

self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

Here is the source code of nn.Module.register_buffer() (from an older PyTorch release; current versions differ in details):

def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
        r"""Adds a buffer to the module.

        This is typically used to register a buffer that should not to be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the module's state. Buffers, by
        default, are persistent and will be saved alongside parameters. This
        behavior can be changed by setting :attr:`persistent` to ``False``. The
        only difference between a persistent buffer and a non-persistent buffer
        is that the latter will not be a part of this module's
        :attr:`state_dict`.

        Buffers can be accessed as attributes using given names.

        Args:
            name (string): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor): buffer to be registered.
            persistent (bool): whether the buffer is part of this module's
                :attr:`state_dict`.

        Example::

            >>> self.register_buffer('running_mean', torch.zeros(num_features))

        """
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
            raise RuntimeError("ScriptModule does not support non-persistent buffers")

        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call")
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("buffer name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        elif name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists".format(name))
        elif tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError("cannot assign '{}' object to buffer '{}' "
                            "(torch Tensor or None required)"
                            .format(torch.typename(tensor), name))
        else:
            self._buffers[name] = tensor
            if persistent:
                self._non_persistent_buffers_set.discard(name)
            else:
                self._non_persistent_buffers_set.add(name)

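register_buffer is a method of nn.Module. The key detail is the _non_persistent_buffers_set attribute: only buffers registered with persistent=False are added to it, and those are left out of the module's state_dict. With the default persistent=True, as in the call above, pos_table is not a model parameter (it is not returned by parameters() and receives no gradient updates), but it is still part of the state_dict, so torch.save(model.state_dict(), ...) will save it. To keep it out of the saved file, register it with persistent=False. Because the table is deterministic, it is simply recomputed whenever the module is constructed.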
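A quick way to see the difference between parameters, persistent buffers, and non-persistent buffers is a toy module. This is a minimal sketch assuming a PyTorch version that supports the `persistent` argument (1.6 or later); the module `BufferDemo` and the buffer name `cache` are hypothetical.

```python
import torch
import torch.nn as nn


class BufferDemo(nn.Module):
    """Hypothetical toy module: one parameter and two buffers."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(2))                        # learnable parameter
        self.register_buffer('pos_table', torch.zeros(1, 4, 2))         # persistent buffer (default)
        self.register_buffer('cache', torch.zeros(3), persistent=False)  # non-persistent buffer


m = BufferDemo()
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
state_keys = list(m.state_dict().keys())

print(param_names)   # ['weight']
print(buffer_names)  # ['pos_table', 'cache']
print(state_keys)    # ['weight', 'pos_table']  (the non-persistent 'cache' is left out)
```

Both buffers move with `m.to(device)`, but only `pos_table` ends up in the saved state_dict.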

Calculation

The paper uses sine and cosine functions of different frequencies:

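\[\begin{aligned} PE_{(pos,2i)} &= \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \\ PE_{(pos,2i+1)} &= \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) \end{aligned}\]

where pos is the token position, i indexes the pairs of embedding dimensions, and d_model is the embedding size (d_hid in the code below).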
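To make the indexing concrete: dimensions 2i and 2i+1 share one angle, and the even one takes the sine while the odd one takes the cosine. A small sketch in plain Python (the helper name `pe` is hypothetical) evaluates the formula for a single position with d_model = 4, where the two angles are pos/10000^0 and pos/10000^(2/4) = pos/100:

```python
import math


def pe(pos: int, k: int, d_model: int) -> float:
    """Value of embedding dimension k of the positional encoding at position pos."""
    i = k // 2                                  # dims 2i and 2i+1 share the same i
    angle = pos / 10000 ** (2 * i / d_model)
    return math.sin(angle) if k % 2 == 0 else math.cos(angle)


# Position 1, d_model = 4: angles are 1.0 (i = 0) and 0.01 (i = 1)
row = [pe(1, k, 4) for k in range(4)]
print([round(v, 4) for v in row])  # [0.8415, 0.5403, 0.01, 1.0]
```

At position 0 every angle is 0, so that row is always [0, 1, 0, 1, ...].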

In code, the table can be built with NumPy and then converted to a PyTorch tensor:

import numpy as np
import torch


def _get_sinusoid_encoding_table(self, n_position, d_hid):
    ''' Sinusoid position encoding table '''
    # TODO: make it with torch instead of numpy

    def get_position_angle_vec(position):
        # the angle inside sin/cos: pos / 10000^(2i / d_hid);
        # hid_j // 2 maps both dimension 2i and dimension 2i+1 to the same i
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    # [:, 0::2] selects the even dimensions (2i), [:, 1::2] the odd ones (2i+1)
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    # shape (1, n_position, d_hid): the leading 1 broadcasts over the batch
    return torch.tensor(sinusoid_table, dtype=torch.float32).unsqueeze(0)
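The TODO above asks for a torch-only version. One possible sketch, assuming the same (1, n_position, d_hid) output; the function name `sinusoid_table_torch` is hypothetical. It computes in float64 and casts at the end, mirroring the NumPy version, which also works in float64 before conversion.

```python
import torch


def sinusoid_table_torch(n_position: int, d_hid: int) -> torch.Tensor:
    """Sinusoid position table built with torch only; shape (1, n_position, d_hid)."""
    position = torch.arange(n_position, dtype=torch.float64).unsqueeze(1)  # (n_position, 1)
    dim = torch.arange(d_hid)                                               # (d_hid,)
    # 2 * (dim // 2) gives 2i for both dimension 2i and 2i+1, as in the NumPy version
    exponent = (2 * (dim // 2)).to(torch.float64) / d_hid
    angles = position / torch.pow(10000.0, exponent)                        # (n_position, d_hid)
    table = torch.empty_like(angles)
    table[:, 0::2] = torch.sin(angles[:, 0::2])  # even dims: sine
    table[:, 1::2] = torch.cos(angles[:, 1::2])  # odd dims: cosine
    return table.to(torch.float32).unsqueeze(0)


table = sinusoid_table_torch(200, 512)
print(table.shape)  # torch.Size([1, 200, 512])
```

Broadcasting a (n_position, 1) column against a (d_hid,) row replaces the two nested Python loops with a single tensor operation.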

The last part is the forward function:

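def forward(self, x):
    # x: word embeddings of shape (batch, seq_len, d_hid)
    # take the first seq_len rows of the table; its size-1 batch dim broadcasts over the batch
    return x + self.pos_table[:, :x.size(1)].clone().detach()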

In other words, the positional encoding is added element-wise to the output of the word embedding layer, and the same table is shared by every sequence in the batch.
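Putting the pieces together, here is a minimal sketch of the whole module, assuming inputs of shape (batch, seq_len, d_hid) with seq_len at most n_position. The class layout follows the snippets above, but the table is built with torch instead of NumPy, and the default n_position=200 and the embedding sizes in the usage lines are hypothetical values chosen for illustration.

```python
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to word embeddings (sketch)."""

    def __init__(self, d_hid: int, n_position: int = 200):
        super().__init__()
        # Buffer: follows .to(device), is not a parameter,
        # and (persistent=True by default) appears in state_dict()
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    @staticmethod
    def _get_sinusoid_encoding_table(n_position: int, d_hid: int) -> torch.Tensor:
        position = torch.arange(n_position, dtype=torch.float64).unsqueeze(1)  # (n_position, 1)
        exponent = (2 * (torch.arange(d_hid) // 2)).to(torch.float64) / d_hid  # 2i / d_hid per dim
        angles = position / torch.pow(10000.0, exponent)                        # (n_position, d_hid)
        angles[:, 0::2] = torch.sin(angles[:, 0::2])  # even dims: sine
        angles[:, 1::2] = torch.cos(angles[:, 1::2])  # odd dims: cosine
        return angles.float().unsqueeze(0)            # (1, n_position, d_hid)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_hid) with seq_len <= n_position
        return x + self.pos_table[:, :x.size(1)].clone().detach()


# Usage with hypothetical sizes: vocabulary 1000, d_hid 16, batch 2, seq_len 10
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=16)
pos_enc = PositionalEncoding(d_hid=16, n_position=200)
tokens = torch.randint(0, 1000, (2, 10))
out = pos_enc(embedding(tokens))
print(out.shape)  # torch.Size([2, 10, 16])
```

Because pos_table is a buffer, `pos_enc.to(device)` moves it together with the module. It appears in `pos_enc.state_dict()`, while `pos_enc.parameters()` is empty.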

posted @ 2021-08-26 11:48  虾野百鹤