pytorch(2)----基本数据类型与模块
Tensor
张量,包含单一数据类型元素的矩阵。
基本内容
1、数据初始化及数据类型转换
2、组合: torch.cat() 按照某一个维度进行拼接,总维度数目不变
torch.stack() 按照制定维度进行叠加,新增维度
3、分块:torch.chunk() 指定分块数量
torch.split() 指定每块数量
4、 索引、torch.masked_select
5、维度变化:.view().resize() .reshape()
6、.unsqueeze()、.squeeze()
7、Tensor 的排序: .sort()
8、Tensor 的广播机制
示例代码
1、
# 数据类型 及转换 # 默认的数据类型为 torch.FloatTensor # x 实际为torch.FloatTensor x = torch.Tensor(2,2) #2行2列、未初始化 x = torch.rand(2,2) #均匀分布 [[0,1) print(x) print(x.dtype) # 使用 int() double() float() 等直接进行数据类型转换 b = x.double() print(b.dtype) print(type(b)) #使用 .type()函数进行 类型转换 c = x.type(torch.IntTensor) print(c.dtype) print(c) #使用 .type_as()函数转换类型更加 方便 d = x.type_as(c) print(d.dtype) print(d) #其他初始化方法 print('\n其他初始化方法') # 直接给值 c1 = torch.Tensor([[2,3,4],[1,4,5]]) print(c1.shape) print(c1.dtype) # ones() eye() zeros() c2 =torch.eye(5) print(c2) print(c2.dtype) #randn() 标准正太分布 随机数 c3 = torch.randn(4,4) # torch.arange(start,end, step) 生成一维向量 [start,end) c4 = torch.arange(1,6,2) print(c4) #元素个数 print(c1.numel())
2、
# 组合 # torch.cat() 按照某一个维度进行拼接,总维度数目不变 print("torch.cat()----") a = torch.Tensor([[1,2],[3,4]]) print(a) b = torch.Tensor([[5,6],[7,8]]) print(b) # 按照第一维进行拼接 c1 = torch.cat([a,b],0) print(c) # 按照第二维进行拼接 c2 = torch.cat([a,b],1) print(c) # 组合 # torch.stack() 按照制定维度进行叠加,新增维度 print("\n torch.stack()----") print(a) print(b) c3 = torch.stack([a,b],0) print(c3) print(c3.shape) c4 = torch.stack([a,b],1) print(c4) print(c4.shape)
3、
# 分块 # torch.chunk() 制定分块数量 a = torch.Tensor([[1,2,3],[4,5,6]]) print(a) print(a.shape) # 沿着第0维,分成2块 print( torch.chunk(a,2,0) ) print( torch.chunk(a,2,1) ) #分配不均时,前面的个数多于后面 # 分块 # torch.split() 指定每块数量 print(a) print(a.shape) # 沿着第0维,每块的个数为2 print( torch.split(a,2,0) ) print( torch.split(a,2,1) )
4、
# 索引 a = torch.Tensor([[0,1],[6,7]]) # 按下标索引 print(a[0]) print(a[0,1]) # 比较 true 为1,false 为0 print(a>0) # 选择符合条件的元素返回 print(torch.masked_select(a,a>0)) print(a[a>0])
5、
# 维度变化 a = torch.arange(1,5) print(a) print(a.view(2,2)) print(a.resize(4,1)) print(a) print(a.reshape(2,2)) # 原地操作 print(a.resize_(4,1)) print(a)
6、
# 增加维度 # .unsqueeze a = torch.arange(1,4) print(a) print(a.shape) # 将第0维变成1 b1 = a.unsqueeze(0) print(b1) print(b1.shape) #将第1维变成1 b2 = a.unsqueeze(1) print(b2) print(b2.shape) # 减少维度 c1 = b2.squeeze(1) print(c1) print(c1.shape)
7、
# Tensor 的排序 # 按照第0维进行排序,True为降序,false为升序 a = torch.randn(3,3) print(a) b=a.sort(0,True) print("排序结果\n",b[0]) print(a) print("排序结果索引\n",b[1]) # 按照第1维进行排序,True为降序,false为升序 b1=a.sort(1,True) print("排序结果\n",b1[0]) print("排序结果索引\n",b1[1]) # .max .min print(a) c = a.max(0) # 按照第0维: 选出每一列的最大值 print(c[0]) print(c[1])
8、
# Tensor 的广播机制 # 条件: 任一个tensor 至少有一个维度,且从尾到头部遍历整个tensor维度时 a = torch.ones(3,1,2) b = torch.ones(2,1) c = a+b print(c) print(c.shape) d = torch.ones(2,3) #c2 = a+d #error