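Given tensors `a` and `b` with `a.shape == (2, 3)` and `b.shape == (2, 3)`: `torch.cat([a, b], dim=0).shape == (4, 3)` (rows are stacked) and `torch.cat([a, b], dim=1).shape == (2, 6)` (columns are appended). In general, `torch.cat` adds up the sizes along the concatenation dimension, and all other dimensions must match. Note that `shape` is an attribute, not a method; use `a.shape` or `a.size()`, not `a.shape()`.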