PyTorch:stack + reshape与cat之间的异同

value_query = torch.stack([prev_query, now_query], 1)
.reshape(bs*2, num_query, -1).permute(1, 0, 2)

value_query1 = torch.stack([prev_query, now_query], 2)
.reshape(num_query, bs*2, -1)

#而非
value_query1 = torch.stack([prev_query, now_query], 1)
.reshape(num_query, bs*2, -1)

posted @ 2022-09-21 20:59  龙雪  阅读(2)  评论(0编辑  收藏  举报  来源