The difference between batch_first=True and batch_first=False in PyTorch

  • batch_first – If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature). Note that this does not apply to hidden or cell states. See the Inputs/Outputs sections below for details. Default: False
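The note about hidden and cell states can be checked directly. The following is a minimal sketch, assuming an nn.LSTM with the same layer sizes as the RNN examples in this post and an arbitrarily chosen batch size of 7 (picked only so it cannot be confused with the other dimensions):

```python
import torch
import torch.nn as nn

# Assumption: sizes mirror the RNN examples in this post; batch_size=7 is arbitrary.
lstm = nn.LSTM(input_size=4, hidden_size=3, num_layers=2, batch_first=True)
x = torch.randn(7, 5, 4)  # (batch, seq, feature) because batch_first=True

output, (h_n, c_n) = lstm(x)

print(output.shape)  # batch-first layout: (batch, seq, hidden_size)
print(h_n.shape)     # unaffected by batch_first: (num_layers, batch, hidden_size)
print(c_n.shape)     # unaffected by batch_first: (num_layers, batch, hidden_size)
```

Only output follows the batch_first layout; h_n and c_n keep num_layers as their first dimension either way.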

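import torch
import torch.nn as nn

# batch_first=True: the input is laid out as (batch, seq, feature)
rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=2, batch_first=True)
input = torch.randn(1, 5, 4)  # batch_size=1, seq_len=5, input_size=4
output, h_n = rnn(input)
print(output.shape)
print(h_n.shape)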

Output:

#output
tensor([[[-0.4026, -0.2417, -0.1307],
         [-0.0122,  0.4269, -0.7256],
         [ 0.2228,  0.7731, -0.9092],
         [-0.3735,  0.4446, -0.6930],
         [-0.1539,  0.5937, -0.8616]]], grad_fn=<TransposeBackward1>)

#h_n
tensor([[[-0.5664,  0.0416, -0.9316]],
        [[-0.1539,  0.5937, -0.8616]]], grad_fn=<StackBackward0>)

input.shape  = torch.Size([1, 5, 4])  # batch_size, seq_len, input_size
output.shape = torch.Size([1, 5, 3])  # batch_size, seq_len, hidden_size
h_n.shape    = torch.Size([2, 1, 3])  # num_layers, batch_size, hidden_size

batch_size = 1, num_layers = 2, hidden_size = 3, input_size = 4, seq_len = 5
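In the tensors printed above, the last row of output is the same as the second entry of h_n. For a unidirectional RNN this is expected: h_n[-1] is the top layer's hidden state after the final time step, which is also what output holds at its last sequence position. A minimal sketch to check this, where the seed and the name x are assumptions added for reproducibility and are not part of the original run:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # assumption: fixed seed, only for reproducibility

rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=2, batch_first=True)
x = torch.randn(1, 5, 4)  # (batch, seq, feature)
output, h_n = rnn(x)

last_step = output[:, -1, :]  # top-layer output at the final time step: (batch, hidden_size)
top_layer = h_n[-1]           # final hidden state of the top layer: (batch, hidden_size)

matches = torch.allclose(last_step, top_layer)
print(matches)
```

With batch_first=False the same comparison would index the sequence dimension first, i.e. output[-1] instead of output[:, -1, :].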

When batch_first = False

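import torch
import torch.nn as nn

# batch_first=False (the default): the input is laid out as (seq, batch, feature)
rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=2, batch_first=False)
input = torch.randn(1, 5, 4)  # now read as seq_len=1, batch_size=5, input_size=4
output, h_n = rnn(input)
print(output.shape)
print(h_n.shape)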
 
Output:
input.shape  = torch.Size([1, 5, 4])  # seq_len, batch_size, input_size
output.shape = torch.Size([1, 5, 3])  # seq_len, batch_size, hidden_size
h_n.shape    = torch.Size([2, 5, 3])  # num_layers, batch_size, hidden_size

seq_len = 1, num_layers = 2, hidden_size = 3, input_size = 4, batch_size = 5
 
#output
tensor([[[-0.6884,  0.0477, -0.3248],
         [-0.5575, -0.0757, -0.4916],
         [-0.6645,  0.2197, -0.4582],
         [-0.6820,  0.1047, -0.4033],
         [-0.6624,  0.0487, -0.3798]]], grad_fn=<StackBackward0>)
#h_n
tensor([[[ 0.4404, -0.4511,  0.2594],
         [ 0.8487,  0.3987,  0.2429],
         [ 0.0287, -0.7793, -0.4574],
         [-0.4603, -0.7794,  0.4563],
         [ 0.4304, -0.3424,  0.1715]],

        [[-0.6884,  0.0477, -0.3248],
         [-0.5575, -0.0757, -0.4916],
         [-0.6645,  0.2197, -0.4582],
         [-0.6820,  0.1047, -0.4033],
         [-0.6624,  0.0487, -0.3798]]], grad_fn=<StackBackward0>)
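The same (1, 5, 4) tensor is therefore read as one sequence of five steps when batch_first=True, but as five sequences of one step each when batch_first=False. To feed identical data to both settings, the input has to be transposed. The sketch below is one way to confirm this, assuming two RNNs that share weights via load_state_dict; the names rnn_bf and rnn_sf, the seed, and the batch size of 7 are illustrative choices, not from the original post:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # assumption: fixed seed, only for reproducibility

# Two RNNs that differ only in batch_first; the weights are copied so they compute the same function.
rnn_bf = nn.RNN(input_size=4, hidden_size=3, num_layers=2, batch_first=True)
rnn_sf = nn.RNN(input_size=4, hidden_size=3, num_layers=2, batch_first=False)
rnn_sf.load_state_dict(rnn_bf.state_dict())

x = torch.randn(7, 5, 4)                  # (batch=7, seq=5, feature=4)
x_seq_first = x.transpose(0, 1).contiguous()  # (seq=5, batch=7, feature=4)

out_bf, h_bf = rnn_bf(x)
out_sf, h_sf = rnn_sf(x_seq_first)

print(out_bf.shape, out_sf.shape)  # (7, 5, 3) versus (5, 7, 3)
print(h_bf.shape, h_sf.shape)      # both (2, 7, 3)

# The outputs hold the same values once the first two dimensions are swapped back.
same_output = torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-5)
same_hidden = torch.allclose(h_bf, h_sf, atol=1e-5)
print(same_output, same_hidden)
```

Passing x to rnn_sf without the transpose would not raise an error, because both layouts end in a feature dimension of 4; it would silently treat the batch dimension as time, which is why checking shapes is worthwhile.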

  

  

Posted by 华小电