torch.triu 测试
torch.triu(input, diagonal=0, *, out=None) → Tensor
返回一个上三角矩阵
参数:
input:输入的张量
diagonal:对角线
返回矩阵(二维张量)或矩阵批次 input 的上三角部分,结果张量 out 的其他元素设置为 0。
矩阵的上三角部分定义为对角线上和之上的元素。
参数 diagonal 控制要考虑的对角线。如果 diagonal = 0,则保留主对角线之上和之上的所有元素。正值不包括主对角线上方的对角线,类似地,负值包括主对角线下方的对角线。主对角线是
的索引集 ,其中 是矩阵的维度。
diagonal=0是保留主对角线以及上面的总和
diagonal=1是保留主对角线上面一行
diagonal=1是保留主对角线上面2行
diagonal=-1是保留主对角线下面一行以及以上
diagonal=-2是保留主对角线下面2行以及以上
import torch
len_s = 7
x = torch.rand((1, len_s, len_s))
print("====>>>>>>>>>>>>>x")
print(x)
print()
a0 = torch.triu(x, diagonal=0)
a1 = torch.triu(x, diagonal=1)
a2 = torch.triu(x, diagonal=2)
a_1 = torch.triu(x, diagonal=-1)
a_2 = torch.triu(x, diagonal=-2)
print("====>>>>>>>>>>>>>a0")
print(a0)
print("====>>>>>>>>>>>>>a1")
print(a1)
print("====>>>>>>>>>>>>>a2")
print(a2)
print("====>>>>>>>>>>>>>a_1")
print(a_1)
print("====>>>>>>>>>>>>>a_2")
print(a_2)
====>>>>>>>>>>>>>x
tensor([[[0.8693, 0.2737, 0.9769, 0.8450, 0.7284, 0.0411, 0.2800],
[0.3573, 0.7882, 0.0233, 0.3003, 0.8225, 0.8858, 0.8372],
[0.6543, 0.3388, 0.0010, 0.7694, 0.8262, 0.3854, 0.2707],
[0.3192, 0.3134, 0.0402, 0.6848, 0.1572, 0.2816, 0.0999],
[0.0573, 0.7577, 0.6373, 0.5226, 0.9810, 0.1912, 0.9347],
[0.1492, 0.9502, 0.3311, 0.2077, 0.7754, 0.9187, 0.6188],
[0.1916, 0.2040, 0.9452, 0.0766, 0.7309, 0.5191, 0.3115]]])
====>>>>>>>>>>>>>a0
tensor([[[0.8693, 0.2737, 0.9769, 0.8450, 0.7284, 0.0411, 0.2800],
[0.0000, 0.7882, 0.0233, 0.3003, 0.8225, 0.8858, 0.8372],
[0.0000, 0.0000, 0.0010, 0.7694, 0.8262, 0.3854, 0.2707],
[0.0000, 0.0000, 0.0000, 0.6848, 0.1572, 0.2816, 0.0999],
[0.0000, 0.0000, 0.0000, 0.0000, 0.9810, 0.1912, 0.9347],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9187, 0.6188],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3115]]])
====>>>>>>>>>>>>>a1
tensor([[[0.0000, 0.2737, 0.9769, 0.8450, 0.7284, 0.0411, 0.2800],
[0.0000, 0.0000, 0.0233, 0.3003, 0.8225, 0.8858, 0.8372],
[0.0000, 0.0000, 0.0000, 0.7694, 0.8262, 0.3854, 0.2707],
[0.0000, 0.0000, 0.0000, 0.0000, 0.1572, 0.2816, 0.0999],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1912, 0.9347],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6188],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
====>>>>>>>>>>>>>a2
tensor([[[0.0000, 0.0000, 0.9769, 0.8450, 0.7284, 0.0411, 0.2800],
[0.0000, 0.0000, 0.0000, 0.3003, 0.8225, 0.8858, 0.8372],
[0.0000, 0.0000, 0.0000, 0.0000, 0.8262, 0.3854, 0.2707],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2816, 0.0999],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9347],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
====>>>>>>>>>>>>>a_1
tensor([[[0.8693, 0.2737, 0.9769, 0.8450, 0.7284, 0.0411, 0.2800],
[0.3573, 0.7882, 0.0233, 0.3003, 0.8225, 0.8858, 0.8372],
[0.0000, 0.3388, 0.0010, 0.7694, 0.8262, 0.3854, 0.2707],
[0.0000, 0.0000, 0.0402, 0.6848, 0.1572, 0.2816, 0.0999],
[0.0000, 0.0000, 0.0000, 0.5226, 0.9810, 0.1912, 0.9347],
[0.0000, 0.0000, 0.0000, 0.0000, 0.7754, 0.9187, 0.6188],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5191, 0.3115]]])
====>>>>>>>>>>>>>a_2
tensor([[[0.8693, 0.2737, 0.9769, 0.8450, 0.7284, 0.0411, 0.2800],
[0.3573, 0.7882, 0.0233, 0.3003, 0.8225, 0.8858, 0.8372],
[0.6543, 0.3388, 0.0010, 0.7694, 0.8262, 0.3854, 0.2707],
[0.0000, 0.3134, 0.0402, 0.6848, 0.1572, 0.2816, 0.0999],
[0.0000, 0.0000, 0.6373, 0.5226, 0.9810, 0.1912, 0.9347],
[0.0000, 0.0000, 0.0000, 0.2077, 0.7754, 0.9187, 0.6188],
[0.0000, 0.0000, 0.0000, 0.0000, 0.7309, 0.5191, 0.3115]]])
Process finished with exit code 0
好记性不如烂键盘---点滴、积累、进步!