Python/PyTorch: Coordinate-System Transforms and Dimension Transforms

Preface

Tensor axis permutations and shape transformations come up constantly in deep learning, so this post records how they work to avoid confusion.

Coordinate-system transforms

A coordinate-system transform (matrix transpose) mainly swaps the dimensions (axes) of a tensor/array.

pytorch

import torch

def info(tensor):
    print(f"tensor: {tensor}")
    print(f"tensor size: {tensor.size()}")
    print(f"tensor is contiguous: {tensor.is_contiguous()}")
    print(f"tensor stride: {tensor.stride()}")

tensor = torch.rand([1,2,3])
info(tensor)

# output:
# tensor: tensor([[[0.9516, 0.2289, 0.0042],
#          [0.2808, 0.4321, 0.8238]]])
# tensor size: torch.Size([1, 2, 3])
# tensor is contiguous: True
# tensor stride: (6, 3, 1)

per_tensor = tensor.permute(0,2,1)
info(per_tensor)

# output:
# tensor: tensor([[[0.9516, 0.2808],
#          [0.2289, 0.4321],
#          [0.0042, 0.8238]]])
# tensor size: torch.Size([1, 3, 2])
# tensor is contiguous: False
# tensor stride: (6, 1, 3)

numpy

import numpy as np

def np_info(array):
    print(f"array: {array}")
    print(f"array size: {array.shape}")
    print(f"array is contiguous: {array.flags['C_CONTIGUOUS']}")
    print(f"array stride: {array.strides}")

array = np.random.rand(1,2,3)
np_info(array)

# output:
# array: [[[0.58227139 0.32251543 0.12221412]
#   [0.72647191 0.42323578 0.65290986]]]
# array size: (1, 2, 3)
# array is contiguous: True
# array stride: (48, 24, 8)

trans_array = np.transpose(array, (0,2,1))
np_info(trans_array)

# output:
# array: [[[0.58227139 0.72647191]
#   [0.32251543 0.42323578]
#   [0.12221412 0.65290986]]]
# array size: (1, 3, 2)
# array is contiguous: False
# array stride: (48, 8, 24)
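The same point can be checked directly in NumPy: a transpose is only a view with permuted strides, and `np.ascontiguousarray` is the standard way to materialize a fresh C-contiguous copy. A minimal sketch (variable names are illustrative):

```python
import numpy as np

# np.transpose returns a view with permuted strides; no data moves.
# np.ascontiguousarray copies the values into new C-contiguous memory.
a = np.arange(6).reshape(1, 2, 3)   # C-contiguous
t = np.transpose(a, (0, 2, 1))      # view, strides permuted, not C-contiguous
c = np.ascontiguousarray(t)         # fresh C-contiguous copy of the same values

print(t.flags['C_CONTIGUOUS'])      # False
print(c.flags['C_CONTIGUOUS'])      # True
print(np.shares_memory(a, t))       # True  -> transpose is just a view
print(np.shares_memory(t, c))       # False -> the copy owns new memory
```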

So for a high-dimensional tensor, a transpose/permute does not change the data's relative positions in memory at all; it only rotates the (hyper)cube of data, i.e. changes the angle from which the (hyper)cube is observed.
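This "only the viewing angle changes" claim can be verified in PyTorch: `permute` rewrites sizes and strides but leaves the underlying storage untouched, so writes through the permuted view are visible in the original. A minimal sketch (variable names are illustrative):

```python
import torch

# permute only rewrites sizes/strides; the underlying storage is shared.
x = torch.arange(6).reshape(1, 2, 3)
p = x.permute(0, 2, 1)                # shape (1, 3, 2), same storage

print(x.data_ptr() == p.data_ptr())   # True: same memory, new "viewing angle"
p[0, 0, 1] = 100                      # write through the permuted view...
print(x[0, 1, 0].item())              # ...and the original sees it: 100
```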

Dimension transforms

tensor.view()

view() mainly reshapes a tensor into the desired size; it does not affect the contiguous flag.
view() is essentially a reference to the tensor: operations through it act directly on the original tensor, no copy is made, and the output shares storage with the input.

view_tensor = tensor.view(3,2,1)
info(view_tensor)

# output:
# tensor: tensor([[[0.9516],
#          [0.2289]],
# 
#         [[0.0042],
#          [0.2808]],
# 
#         [[0.4321],
#          [0.8238]]])
# tensor size: torch.Size([3, 2, 1])
# tensor is contiguous: True
# tensor stride: (2, 1, 1)
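The "view is a reference, not a copy" claim above can be confirmed the same way as with permute: the viewed tensor reports the same storage pointer, and mutating it mutates the original. A minimal sketch (variable names are illustrative):

```python
import torch

# view returns a new tensor header over the same storage -- no copy is made.
t = torch.arange(6).reshape(1, 2, 3)
v = t.view(3, 2, 1)

print(t.data_ptr() == v.data_ptr())   # True: shared storage
v[0, 0, 0] = -1                       # mutate through the view
print(t[0, 0, 0].item())              # the original sees the change: -1
```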

But calling view on a tensor whose contiguous flag is False raises an error:

view_per_tensor = per_tensor.view(2,3)

#output:
# ---------------------------------------------------------------------------
# RuntimeError                              Traceback (most recent call last)
# /tmp/ipykernel_388070/1679121630.py in <module>
# ----> 1 view_per_tensor  = per_tensor.view(2,3)
#       2 # info(per_tensor)
#       3 info(view_per_tensor)
#       4 print(view_per_tensor.data_ptr() == per_tensor.data_ptr())

# RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
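The workaround the error message hints at can also be written explicitly: call `.contiguous()` first to materialize a contiguous copy, after which `view` succeeds. A minimal sketch (variable names are illustrative):

```python
import torch

# Standard fix for the RuntimeError above: materialize a contiguous copy
# first, then view. contiguous() on a non-contiguous tensor copies the data.
t = torch.arange(6).reshape(1, 2, 3)
p = t.permute(0, 2, 1)                     # non-contiguous, view(2,3) would fail

fixed = p.contiguous().view(2, 3)          # works: view runs on the fresh copy
print(fixed.is_contiguous())               # True
print(fixed.data_ptr() == p.data_ptr())    # False: the copy lives in new memory
```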

tensor.reshape()

torch.Tensor.reshape() works on any tensor; it is roughly equivalent to torch.Tensor.view() plus torch.Tensor.contiguous().view(). In other words, reshape does not necessarily allocate new memory either: if the tensor is contiguous, it actually calls the view implementation, and only when the tensor is non-contiguous and its strides are incompatible with the target shape does it make a deep copy.

reshape_per_tensor = per_tensor.reshape(2,3) 
info(reshape_per_tensor)

# output:
# tensor: tensor([[0.9516, 0.2808, 0.2289],
#         [0.4321, 0.0042, 0.8238]])
# tensor size: torch.Size([2, 3])
# tensor is contiguous: True
# tensor stride: (3, 1)
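The "view when possible, copy when necessary" behavior of reshape can be observed through the storage pointer: on a contiguous input it shares memory like view, while on a stride-incompatible input it silently copies. A minimal sketch (variable names are illustrative):

```python
import torch

# reshape on a contiguous tensor behaves like view (shared storage);
# on a stride-incompatible tensor it silently makes a deep copy.
t = torch.arange(6).reshape(2, 3)
r1 = t.reshape(3, 2)                   # contiguous input -> view, no copy
print(r1.data_ptr() == t.data_ptr())   # True

p = t.t()                              # transpose: non-contiguous
r2 = p.reshape(2, 3)                   # incompatible strides -> deep copy
print(r2.data_ptr() == p.data_ptr())   # False
```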
