Pytorch中 batch_first 选择True/False的区别

  • batch_first – If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature). Note that this does not apply to hidden or cell states. See the Inputs/Outputs sections below for details. Default: False


import torch
from torch import nn

# batch_first=True: input/output tensors are (batch, seq, feature).
# NOTE: h_n is NOT affected by batch_first — it stays (num_layers, batch, hidden).
rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=2, batch_first=True)
x = torch.randn(1, 5, 4)   # (batch_size=1, seq_len=5, input_size=4); avoid shadowing builtin `input`
output, h_n = rnn(x)
print(output.shape)        # torch.Size([1, 5, 3])  -> (batch, seq, hidden_size)
print(h_n.shape)           # torch.Size([2, 1, 3])  -> (num_layers, batch, hidden_size)

输出结果:

#output

tensor([[[-0.4026, -0.2417, -0.1307],
             [-0.0122, 0.4269, -0.7256],
             [ 0.2228, 0.7731, -0.9092],
             [-0.3735, 0.4446, -0.6930],
             [-0.1539, 0.5937, -0.8616]]], grad_fn=<TransposeBackward1>)

#h_n
tensor([[[-0.5664, 0.0416, -0.9316]],
             [[-0.1539, 0.5937, -0.8616]]], grad_fn=<StackBackward0>)


input.Size([1, 5, 4])  # batch_size, seq_len, feature
output.Size([1, 5, 3]) # batch_size, seq_len, hidden_size
h_n.Size([2, 1, 3])  # num_layers , batch_size, hidden_size

batch_size = 1 num_layers = 2 hidden_size = 3 input_size = 4 seq_len = 5

当batch_first = False

import torch
from torch import nn

# batch_first=False (the default): input/output tensors are (seq, batch, feature),
# so the SAME (1, 5, 4) tensor is now read as seq_len=1, batch_size=5.
rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=2, batch_first=False)
x = torch.randn(1, 5, 4)   # (seq_len=1, batch_size=5, input_size=4); avoid shadowing builtin `input`
output, h_n = rnn(x)
print(output.shape)        # torch.Size([1, 5, 3])  -> (seq, batch, hidden_size)
print(h_n.shape)           # torch.Size([2, 5, 3])  -> (num_layers, batch, hidden_size)

输出结果:
input.Size([1, 5, 4]) # seq_len, batch_size, feature
output.Size([1, 5, 3]) # seq_len, batch_size, hidden_size
h_n.Size([2, 5, 3])  # num_layers, batch_size, hidden_size

seq_len = 1  num_layers = 2 hidden_size = 3  input_size = 4 batch_size = 5

#output
tensor([[[-0.6884,  0.0477, -0.3248],
         [-0.5575, -0.0757, -0.4916],
         [-0.6645,  0.2197, -0.4582],
         [-0.6820,  0.1047, -0.4033],
         [-0.6624,  0.0487, -0.3798]]], grad_fn=<StackBackward0>)
#h_n
tensor([[[ 0.4404, -0.4511,  0.2594],
         [ 0.8487,  0.3987,  0.2429],
         [ 0.0287, -0.7793, -0.4574],
         [-0.4603, -0.7794,  0.4563],
         [ 0.4304, -0.3424,  0.1715]],

        [[-0.6884,  0.0477, -0.3248],
         [-0.5575, -0.0757, -0.4916],
         [-0.6645,  0.2197, -0.4582],
         [-0.6820,  0.1047, -0.4033],
         [-0.6624,  0.0487, -0.3798]]], grad_fn=<StackBackward0>)

  

  

posted @ 2022-04-25 08:49  华小电  阅读(607)  评论(0编辑  收藏  举报