Pytorch:Tensor 张量操作
张量操作
一、张量的拼接与切分
1.1 torch.cat()
功能:将张量按维度dim进行拼接
tensors:张量序列
dim:要拼接的维度
1.2 torch.stack()
功能:在新创建的维度的上进行拼接
tensors:张量序列
dim:要拼接的维度(如果dim为新的维度,则新增一个维度进行拼接,新维度只能高一维)
1.3 torch.chunk()
功能:将张量按维度进行平均切分
返回值:张量列表
注意事项:若不能整除,最后一份小于其他张量;整除时令商为向上取整的数,如7/3=2.333,取整为3
input:要切分的张量
chunks:要切分的份数
dim:要切分的维度
1.4 torch.split()
功能:将张量按维度进行平均切分
返回值:张量列表
input:要切分的张量
split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分(注意list的各元素之和需等于维度上的长度)
dim:要切分的维度
二、张量索引
2.1 torch.index_select()
功能:在维度dim上,按index索引数据
返回值:依index索引数据拼接的张量
input:要索引的张量
dim:要索引的维度
index:要索引数据的序号(注意index的数据类型要为torch.long,float会报错)
2.2 torch.masked_select()
功能:按mask中的True进行索引
返回值:一维张量
input:要索引的张量
mask:与input同形状的布尔类型张量(mask的生成可以通过比较大小关系得出,le为小于等于,详见图英文注释)
三、张量变换
3.1 torch.reshape()
功能:变换张量形状
注意事项:当张量在内存中是连续时,新张量与input共享数据内存
input:要变换的张量
size:新张量的形状(形状中若有-1,则该处的值有其他维数及总数来计算得出)
3.2 torch.transpose()
功能:交换两个张量的维度
input:要交换的张量
dim0:要交换的维度
dim1:要交换的维度
3.3 torch.t()
功能:2维张量转置,对矩阵而言,等价于torch.transpose(input,0,1)
3.4 torch.squeeze()
功能:压缩长度为1的维度(轴)
dim:若为None,移除所有长度为1的轴;如果指定维度,当且仅当该轴长度为1时,可以被移除
3.5 torch.unsqueeze()
功能:依据dim扩展维度
dim:扩展的维度
三、张量数学运算
主要可分为三类:
1.加减乘除 2. 对数、指数、幂函数 3.三角函数
其中加法比较特殊
torch.add()
功能:逐元素计算该式 input+alpha*other(为了简便于梯度下降的运算)
input:第一个张量
alpha:乘项因子
other:第二个张量
另外的拓展还有