Batched (block-wise) multiplication of 3-D tensors
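import torch

# torch.range is deprecated; torch.arange excludes the end value, so use 5 and 13.
a = torch.arange(1, 5, dtype=torch.float32).reshape(2, 1, 2)
b = torch.arange(1, 13, dtype=torch.float32).reshape(2, 2, 3)

# Batched product: c[i] = a[i] @ b[i] for each batch index i.
c = torch.bmm(a, b)
print('c')
print(c)
print(c.shape)  # torch.Size([2, 1, 3])

# Same result, computed one 2-D block at a time with torch.mm.
d = torch.zeros(2, 1, 3)
for i in range(2):
    a_ = a[i, :, :]
    b_ = b[i, :, :]
    c_ = torch.mm(a_, b_)
    d[i, :, :] = c_
print('d')
print(d)
print(d.shape)  # torch.Size([2, 1, 3])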
torch.bmm only accepts 3-D tensors: https://blog.csdn.net/qq_40178291/article/details/100302375
torch.mm (2-D matrices only): https://blog.csdn.net/da_kao_la/article/details/87484403
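The dimension limits noted above can be checked directly. The following is a minimal sketch, not taken from the linked posts: the variable names (bmm_accepts_2d, mm_accepts_3d, a4, b4, c4) are made up for illustration, and torch.matmul is brought in as an assumed alternative for inputs with more than 3 dimensions.

```python
import torch

a = torch.arange(1, 5, dtype=torch.float32).reshape(2, 1, 2)
b = torch.arange(1, 13, dtype=torch.float32).reshape(2, 2, 3)

# torch.bmm expects both inputs to be 3-D; 2-D slices are rejected.
try:
    torch.bmm(a[0], b[0])
    bmm_accepts_2d = True
except RuntimeError:
    bmm_accepts_2d = False

# torch.mm expects both inputs to be 2-D; 3-D tensors are rejected.
try:
    torch.mm(a, b)
    mm_accepts_3d = True
except RuntimeError:
    mm_accepts_3d = False

# torch.matmul treats the leading dimensions as batch dimensions,
# so for 3-D inputs it agrees with torch.bmm.
c_bmm = torch.bmm(a, b)
c_matmul = torch.matmul(a, b)
same = torch.allclose(c_bmm, c_matmul)

# With an extra leading dimension (4-D), torch.bmm no longer applies,
# but torch.matmul still multiplies the last two dimensions.
a4 = a.unsqueeze(0)        # shape (1, 2, 1, 2)
b4 = b.unsqueeze(0)        # shape (1, 2, 2, 3)
c4 = torch.matmul(a4, b4)  # shape (1, 2, 1, 3)

print(bmm_accepts_2d, mm_accepts_3d, same, c4.shape)
```

If the tensors have more than 3 dimensions and torch.bmm is still preferred, one option is to reshape the leading dimensions into a single batch dimension first and reshape the result back afterwards.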