pytorch入门--索引与切片
其他相关操作:https://blog.csdn.net/qq_43923588/article/details/108007534
本篇对torch的tensor索引与切片进行展示,包含:
- 根据维度进行索引
- [start : end : step]方式索引
- 索引的切片
- 使用 … 表示任意多的维度
- 使用 mask 进行索引
- 使用 take 进行索引
使用方法和含义均在代码的批注中给出,因为有较多的输出,所以设置输出内容的第一个值为当前print()方法所在的行
索引与切片
import torch
import numpy as np
import sys
loc = sys._getframe()
_ = '\n'
'''给定维度进行索引'''
a = torch.rand(4, 3, 8, 8)
print(loc.f_lineno, _, a[0].shape, _, a[0, 0].shape, _, a[0, 0, 2])
'''python中常用索引方式'''
'''[start:end:step]方式的索引'''
# 以a为例,取其前两张图片,包含其左边界值,不包含右边界
print(loc.f_lineno, _, a[:2].shape)
# 取其前两张图片,并对其通道进行索引,中有一个:表示取其全部值
print(loc.f_lineno, _, a[:2, :1].shape, _, a[:2, :1, :, :].shape)
# 取后面的图片
print(loc.f_lineno, _, a[2:].shape)
# 反向索引,例:[0, 1, 2, 3] 取[-1:]时渠道的值为3;取[-2:]时取到的值为2,3
print(loc.f_lineno, _, a[:, -1:].shape)
# 使用::进行切片,选取时设置间隔,第二个:后面为步长,若没有第二个:则默认间隔为1
print(loc.f_lineno, _, a[:, :, 0:8:2, 0:8:2].shape, _, a[:, :, ::2, ::2].shape)
'''给定具体索引的切片'''
# 在a的第0个维度,选取第0张和第2张图片,需要注意第二个index_select()的list参数必须是tensor类型
print(loc.f_lineno, _, a.index_select(0, torch.tensor([0, 2])).shape)
# 使用新变量接收索引后的值
b = torch.index_select(a, 0, torch.tensor([0, 2]))
print(loc.f_lineno, _, b.shape)
# 使用arange()生成tensor类型的list作为索引
print(loc.f_lineno, _, a.index_select(2, torch.arange(8)).shape)
'''使用...表示任意多的维度'''
# 表示a本身的所有维度
print(loc.f_lineno, _, a[...].shape)
# a[0]的等价
print(loc.f_lineno, _, a[0, ...].shape)
# 取第一张图片,中间全取,最后隔行取
print(loc.f_lineno, _, a[0, ..., ::2].shape)
'''使用mask进行索引,选取满足你所摄到的数值条件的值'''
c = torch.randn(3, 4)
print(loc.f_lineno, _, c)
# .ge()方法作用是选取大于大于0.5的元素,将其位置置1(True),其余置0(False)
mask = c.ge(0.5)
print(loc.f_lineno, _, mask)
# 使用mask索引,输出满足条件的位置的值
print(loc.f_lineno, _, torch.masked_select(c, mask))
'''take()函数的使用'''
# 先将tensor转换为一个list,之后根据索引选取list对应位置的值
d = torch.randn(2, 3)
print(loc.f_lineno, _, d)
print(loc.f_lineno, _, torch.take(d, torch.tensor([0, 2, 4])))
运行结果
10
torch.Size([3, 8, 8])
torch.Size([8, 8])
tensor([0.4954, 0.4472, 0.0453, 0.8851, 0.9130, 0.1559, 0.7304, 0.4253])
15
torch.Size([2, 3, 8, 8])
17
torch.Size([2, 1, 8, 8])
torch.Size([2, 1, 8, 8])
19
torch.Size([2, 3, 8, 8])
21
torch.Size([4, 1, 8, 8])
23
torch.Size([4, 3, 4, 4])
torch.Size([4, 3, 4, 4])
28
torch.Size([2, 3, 8, 8])
31
torch.Size([2, 3, 8, 8])
33
torch.Size([4, 3, 8, 8])
38
torch.Size([4, 3, 8, 8])
40
torch.Size([3, 8, 8])
42
torch.Size([3, 8, 4])
47
tensor([[ 0.4089, 0.1719, -0.9253, -0.6969],
[-0.0059, -0.3264, 0.4777, 1.2088],
[ 0.3023, 0.1969, 1.7225, -0.6645]])
50
tensor([[False, False, False, False],
[False, False, False, True],
[False, False, True, False]])
52
tensor([1.2088, 1.7225])
57
tensor([[ 0.7402, 1.6835, 1.9991],
[ 1.1355, 2.0515, -0.3010]])
58
tensor([0.7402, 1.9991, 2.0515])
Process finished with exit code 0