Pytorch tensor 常用操作 备查
生成值范围在[0,1]的张量
a = torch.rand(5, 2, 3)
元素总数
a.numel() # 30
获取张量的形状
a.size()
a.shape
取值(不参与计算)
a.data # 不安全
a.detach() # 推荐
平均值,最大值,最小值
a.detach().mean().item()
a.detach().max().item()
a.detach().min().item()
最大值对应索引
torch.argmax(a.detach()).item()
torch.max(a.detach()).indices.item()
# 存在dim参数时, 结果的形状为去掉dim对应维度后的形状
dim = 0
m = torch.argmax(a, dim)
print('m', m.size(), m)
tensor 转 ndarray
# 数据在CPU
a.detach().numpy()
# 数据在GPU
a.cpu().detach().numpy()
ndarray 转 tensor
b=torch.from_numpy(a)
变形
a=a.view(3,10) # 3个10
a=a.view(5,-1) # 5个N
a=a.view(-1,5) # N个5
去除形状为1的维度
a = torch.rand(5, 1, 3)
a = a.squeeze()
升维度,加入的维度形状为1
a = torch.rand(5, 2, 3)
# 加在最前
a=a.unsqueeze(0)
# 加在最后
a=a.unsqueeze(-1)
拼接
# 按维数0拼接(竖着拼)
C = torch.cat((A, B), 0)
# 按维数1拼接(横着拼)
C = torch.cat((A, B), 1)