pytorch flatten()

 

 torch.flatten(input, start_dim, end_dim).

举例:一个tensor 3*2* 2

start_dim=1  output 3*4

start_dim=0 end_dim=1.    6*2

如果没有后面两个参数直接变为一维的

 

posted @ 2022-01-29 14:43  Tomorrow1126  阅读(79)  评论(0编辑  收藏  举报