PyTorch Custom Parameter Layers
Note: the official layers generally come with autograd support built in. If your custom layer is built only from differentiable built-in ops, gradients flow automatically; if it is not, you have to implement the backward (gradient) pass yourself.
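For reference, here is a minimal sketch of what a hand-written backward looks like: a custom `torch.autograd.Function` with both `forward` and `backward` defined (this `Exp` example is illustrative, not from the original post).

```python
import torch


class Exp(torch.autograd.Function):
    """Illustrative: e^x with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, input):
        result = input.exp()
        ctx.save_for_backward(result)  # stash tensors needed by backward
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        # d/dx e^x = e^x, chained with the incoming gradient
        return grad_output * result


x = torch.randn(3, requires_grad=True)
y = Exp.apply(x)    # use .apply, not Exp()
y.sum().backward()  # gradients flow through our custom backward
print(x.grad)
```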
The official Linear layer:

```python
import math

import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter


class Linear(Module):
    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Learnable parameters are registered by wrapping tensors in Parameter
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            # Register 'bias' as None so the attribute still exists
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init in [-1/sqrt(in_features), 1/sqrt(in_features)]
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
```
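A quick usage sketch (the shapes are made up for illustration); assuming the class above, it behaves like `torch.nn.Linear`:

```python
layer = Linear(in_features=3, out_features=5)
x = torch.randn(2, 3)   # batch of 2, feature dim 3
out = layer(x)          # shape: (2, 5)
print(layer)            # extra_repr: in_features=3, out_features=5, bias=True
```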
Implementing a view (Reshape) layer:

```python
import torch.nn as nn


class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args  # target shape, excluding the batch dimension

    def forward(self, x):
        # Keep the batch dimension and reshape the remaining dimensions
        return x.view((x.size(0),) + self.shape)
```
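Because Reshape is an `nn.Module`, it can sit inside `nn.Sequential`, e.g. to flatten feature maps before a fully connected layer (a sketch with made-up layer sizes):

```python
import torch

model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (N, 1, 28, 28) -> (N, 8, 28, 28)
    nn.ReLU(),
    Reshape(8 * 28 * 28),                       # -> (N, 6272)
    nn.Linear(8 * 28 * 28, 10),
)
x = torch.randn(4, 1, 28, 28)
print(model(x).shape)  # torch.Size([4, 10])
```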
Implementing a LinearWise layer:

```python
import math

import torch
import torch.nn as nn


class LinearWise(nn.Module):
    def __init__(self, in_features, bias=True):
        super(LinearWise, self).__init__()
        self.in_features = in_features

        # One weight (and optionally one bias) per input feature
        self.weight = nn.Parameter(torch.Tensor(self.in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(self.in_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        # Element-wise scale and shift instead of a full matrix multiply
        x = input * self.weight
        if self.bias is not None:
            x = x + self.bias
        return x
```
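Unlike Linear, LinearWise scales each input feature independently (an element-wise product plus an optional bias), so the output has the same shape as the input. A quick sketch with made-up shapes:

```python
layer = LinearWise(in_features=4)
x = torch.randn(2, 4)
out = layer(x)              # element-wise: same shape as the input
assert out.shape == (2, 4)
```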
Now go become who you want to be!