The difference between batch_first=True and batch_first=False in PyTorch
-
batch_first – If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature). Note that this does not apply to hidden or cell states. See the Inputs/Outputs sections below for details. Default: False
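The note that batch_first "does not apply to hidden or cell states" is easy to overlook, so here is a minimal sketch of it using nn.LSTM. The sizes (batch of 7, sequence length 5) and the zero-filled initial states are assumptions chosen only to keep the dimensions distinguishable; they are not from the original post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumed), chosen so no two dimensions coincide by accident.
batch_size, seq_len, input_size = 7, 5, 4
hidden_size, num_layers = 3, 2

lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

# The input follows batch_first=True: (batch, seq, feature).
x = torch.randn(batch_size, seq_len, input_size)

# The initial states are NOT affected by batch_first:
# they stay (num_layers, batch, hidden_size).
h_0 = torch.zeros(num_layers, batch_size, hidden_size)
c_0 = torch.zeros(num_layers, batch_size, hidden_size)

output, (h_n, c_n) = lstm(x, (h_0, c_0))
print(output.shape)  # torch.Size([7, 5, 3])  (batch, seq, hidden_size)
print(h_n.shape)     # torch.Size([2, 7, 3])  (num_layers, batch, hidden_size)
print(c_n.shape)     # torch.Size([2, 7, 3])  (num_layers, batch, hidden_size)

# A batch-first initial state, (batch, num_layers, hidden_size), is rejected.
rejected = False
try:
    bad_state = torch.zeros(batch_size, num_layers, hidden_size)
    lstm(x, (bad_state, bad_state))
except RuntimeError:
    rejected = True
print(rejected)
```

In other words, only the input and output tensors change layout; any state you pass in or get back keeps the layer dimension first.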
When batch_first = True

    import torch
    import torch.nn as nn

    rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=2, batch_first=True)
    input = torch.randn(1, 5, 4)  # (batch_size, seq_len, feature)
    output, h_n = rnn(input)
    print(output.shape)
    print(h_n.shape)
Output (the values of output and h_n; the input is random, so the numbers differ from run to run):

    #output
    tensor([[[-0.4026, -0.2417, -0.1307],
             [-0.0122,  0.4269, -0.7256],
             [ 0.2228,  0.7731, -0.9092],
             [-0.3735,  0.4446, -0.6930],
             [-0.1539,  0.5937, -0.8616]]], grad_fn=<TransposeBackward1>)
    #h_n
    tensor([[[-0.5664,  0.0416, -0.9316]],
            [[-0.1539,  0.5937, -0.8616]]], grad_fn=<StackBackward0>)
The shapes are:

    input.shape   # torch.Size([1, 5, 4])  (batch_size, seq_len, feature)
    output.shape  # torch.Size([1, 5, 3])  (batch_size, seq_len, hidden_size)
    h_n.shape     # torch.Size([2, 1, 3])  (num_layers, batch_size, hidden_size)

    batch_size = 1
    num_layers = 2
    hidden_size = 3
    input_size = 4
    seq_len = 5
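In the printed values above, the last row of output is identical to the second entry of h_n. A short sketch of why, assuming the same unidirectional two-layer nn.RNN as in the example (the fixed seed and the variable names last_step and top_layer_final are only for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=2, batch_first=True)
x = torch.randn(1, 5, 4)  # (batch_size, seq_len, feature)
output, h_n = rnn(x)

# output holds the top layer's hidden state at every time step;
# h_n holds every layer's hidden state at the final time step.
last_step = output[:, -1, :]  # (batch_size, hidden_size)
top_layer_final = h_n[-1]     # (batch_size, hidden_size)

print(last_step.shape)                              # torch.Size([1, 3])
print(torch.allclose(last_step, top_layer_final))   # True
```

With batch_first=True the time axis of output is dimension 1, so the last step is output[:, -1, :]; h_n is still indexed by layer first, regardless of batch_first.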
When batch_first = False
    rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=2, batch_first=False)
    input = torch.randn(1, 5, 4)  # now read as (seq_len, batch_size, feature)
    output, h_n = rnn(input)
    print(output.shape)
    print(h_n.shape)

Output:

    input.shape   # torch.Size([1, 5, 4])  (seq_len, batch_size, feature)
    output.shape  # torch.Size([1, 5, 3])  (seq_len, batch_size, hidden_size)
    h_n.shape     # torch.Size([2, 5, 3])  (num_layers, batch_size, hidden_size)

    seq_len = 1
    num_layers = 2
    hidden_size = 3
    input_size = 4
    batch_size = 5

    #output
    tensor([[[-0.6884,  0.0477, -0.3248],
             [-0.5575, -0.0757, -0.4916],
             [-0.6645,  0.2197, -0.4582],
             [-0.6820,  0.1047, -0.4033],
             [-0.6624,  0.0487, -0.3798]]], grad_fn=<StackBackward0>)
    #h_n
    tensor([[[ 0.4404, -0.4511,  0.2594],
             [ 0.8487,  0.3987,  0.2429],
             [ 0.0287, -0.7793, -0.4574],
             [-0.4603, -0.7794,  0.4563],
             [ 0.4304, -0.3424,  0.1715]],
            [[-0.6884,  0.0477, -0.3248],
             [-0.5575, -0.0757, -0.4916],
             [-0.6645,  0.2197, -0.4582],
             [-0.6820,  0.1047, -0.4033],
             [-0.6624,  0.0487, -0.3798]]], grad_fn=<StackBackward0>)

The same (1, 5, 4) tensor is now interpreted as 5 sequences of length 1 rather than 1 sequence of length 5. Because seq_len is 1, output coincides with the last layer of h_n.
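To make the comparison concrete, the sketch below builds two RNNs with identical weights, one per setting, and checks that they agree once the input is transposed. Copying weights with load_state_dict and the names rnn_bf / rnn_sf are assumptions made for this illustration; the original post uses two independently initialized modules.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

rnn_bf = nn.RNN(input_size=4, hidden_size=3, num_layers=2, batch_first=True)
rnn_sf = nn.RNN(input_size=4, hidden_size=3, num_layers=2, batch_first=False)
rnn_sf.load_state_dict(rnn_bf.state_dict())  # give both modules the same weights

x = torch.randn(1, 5, 4)  # read as (batch=1, seq=5, feature=4)

out_bf, h_bf = rnn_bf(x)
out_sf, h_sf = rnn_sf(x.transpose(0, 1))  # (seq=5, batch=1, feature=4)

# Same computation, different layout: outputs match after transposing back,
# and h_n matches directly because its layout never depends on batch_first.
print(torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-6))
print(torch.allclose(h_bf, h_sf, atol=1e-6))

# Passing x untransposed to the batch_first=False module does not raise an error;
# it is silently treated as seq_len=1, batch_size=5.
out_wrong, h_wrong = rnn_sf(x)
print(out_wrong.shape)  # torch.Size([1, 5, 3])
print(h_wrong.shape)    # torch.Size([2, 5, 3])
```

The last two lines are the practical risk: a layout mismatch produces valid-looking shapes, so it is worth checking h_n.shape[1] against the intended batch size.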