TransformerEncoder中的语法

PositionalEncoding

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence of embeddings.

    Even feature indices hold ``sin(pos * div_term)`` and odd feature indices
    hold ``cos(pos * div_term)``. The table is precomputed once for
    ``max_len`` positions and registered as a buffer. It is therefore saved
    in ``state_dict`` but is not a trainable parameter.

    Args:
        d_model: Embedding dimension. Both even and odd values are supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed. Sequences longer than this
            cannot be encoded; the addition in ``forward`` fails.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Column of positions, shape [max_len, 1]. It broadcasts against
        # div_term, which has shape [ceil(d_model / 2)].
        position = torch.arange(max_len).unsqueeze(1)
        # div_term[i] = 10000 ** (-2i / d_model), computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        # 0::2 -> even feature indices, 1::2 -> odd feature indices.
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd index than even index.
        # The cosine half therefore uses only the first d_model // 2
        # frequencies. For even d_model this slice is the whole div_term.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add positional encodings to ``x`` and apply dropout.

        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Take the first seq_len rows of the table. The size-1 middle axis
        # broadcasts over the batch dimension.
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

\[PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)\]

\[PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)\]

div_term

div_term = $e^{2i \cdot \left(\frac{-\ln 10000}{d_{model}}\right)} = \left(\frac{1}{10000}\right)^{\frac{2i}{d_{model}}}$

pe[:, 0, 0::2]

pe[:, 0, 0::2] = torch.sin(position * div_term)  # even feature indices (0, 2, 4, ...) receive the sine terms
Example:

pe = torch.zeros(3, 1, 8)  # 3 rows, matching the printed tensors below
pe[:, 0, 0::2] = 1  # set even indices 0, 2, 4, 6 of the last axis to 1
pe:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.]]])
-->

# 第三维的第 0、2、4、6 个元素(即 0::2 选中的偶数下标)被置为 1

tensor([[[1., 0., 1., 0., 1., 0., 1., 0.]],

        [[1., 0., 1., 0., 1., 0., 1., 0.]],

        [[1., 0., 1., 0., 1., 0., 1., 0.]]])

self.register_buffer()

self.register_buffer('pe', pe)  # the name must be 'pe' so that forward() can read it as self.pe
  • 将tensor pe 注册成buffer, 不会有梯度传播给它,但能被模型的 state_dict 记录下来
  • buffer 不会被优化器更新:optim.step 只更新 nn.Parameter 类型的参数;若 buffer 需要变化,需在 forward 等处手动修改(本例中的 pe 初始化后保持不变)。
  • 网络存储时也会将buffer存下,当网络load模型时,会将存储的模型的buffer也进行赋值。

data.uniform_

  • 权重初始化
    def init_weights(self) -> None:
        """Fill the encoder/decoder weights uniformly in [-0.1, 0.1] and zero the decoder bias.

        Reads ``self.encoder`` and ``self.decoder``, which the enclosing
        module must define. They are presumably an embedding layer and a
        linear output layer; confirm against the model definition.
        """
        initrange = 0.1
        # .data.uniform_ fills the tensor in place with U(-initrange, initrange)
        # and bypasses autograd tracking.
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

t()

def batchify(data: Tensor, bsz: int) -> Tensor:
    """Split a flat sequence into ``bsz`` columns and move it to ``device``.

    Trailing elements that cannot fill a complete column are dropped.
    ``device`` is not defined in this function; it is presumably a
    module-level global set elsewhere in the file.

    Args:
        data: Sequence to split. It is assumed to be 1-D (e.g. token ids),
            because ``view(bsz, ...)`` below requires it; TODO confirm
            against callers.
        bsz: Number of columns (batch size).

    Returns:
        Tensor of shape [data.size(0) // bsz, bsz] on ``device``.
    """
    steps_per_column = data.size(0) // bsz
    kept = data[:steps_per_column * bsz]  # drop the remainder that can't fill a column
    # view gives [bsz, seq_len]; .t() transposes it to [seq_len, bsz].
    # contiguous() lays the transposed view out compactly in memory.
    laid_out = kept.view(bsz, steps_per_column).t().contiguous()
    return laid_out.to(device)
posted @ 2022-03-18 16:42  ArdenWang  阅读(189)  评论(0)  编辑  收藏  举报