Pytorch tensor 常用操作 备查

生成值范围在[0,1]的张量

a = torch.rand(5, 2, 3)

元素总数

a.numel()  # 30

获取张量的形状

a.size()
a.shape

取值(不参与计算)

a.data # 不安全
a.detach() # 推荐

平均值,最大值,最小值

a.detach().mean().item()
a.detach().max().item()
a.detach().min().item()

最大值对应索引

torch.argmax(a.detach()).item()
torch.max(a.detach()).indices.item()

# 存在dim参数时, 结果的形状为去掉dim对应维度后的形状
dim = 0
m = torch.argmax(a, dim)
print('m', m.size(), m)

tensor 转 ndarray

# 数据在CPU
a.detach().numpy()
# 数据在GPU
a.cpu().detach().numpy()

ndarray 转 tensor

b=torch.from_numpy(a)

变形

a=a.view(3,10)  # 3个10
a=a.view(5,-1)  # 5个N
a=a.view(-1,5)  # N个5

去除形状为1的维度

a = torch.rand(5, 1, 3)
a = a.squeeze()

升维度,加入的维度形状为1

a = torch.rand(5, 2, 3)
# 加在最前
a=a.unsqueeze(0)
# 加在最后
a=a.unsqueeze(-1)

拼接

# 按维数0拼接(竖着拼)
C = torch.cat((A, B), 0)

# 按维数1拼接(横着拼)
C = torch.cat((A, B), 1)
posted @ 2021-05-02 15:34  太晓  阅读(127)  评论(0编辑  收藏  举报