拼接tensor
torch.cat(tensors, dim)
: 沿指定维度拼接张量。
tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]]) tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]]) # dim=0 表示沿着第一个维度(行的方向)进行连接。 concatenated_tensor = torch.cat([tensor1, tensor2], dim=0) tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]])
# dim=1 表示沿着第二个维度(列的方向)进行连接。
concatenated_tensor = torch.cat([tensor1, tensor2], dim=1)
tensor([[ 1, 2, 3, 7, 8, 9],
[ 4, 5, 6, 10, 11, 12]])