torch.unique

写代码的时候想把一个张量X中的最后一个维度进行类似集合那样的操作,于是网上找到了torch.unique这个方法(官方文档

torch.unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) → Tuple[Tensor, Tensor, Tensor]

其中参数sortedreturn_inversereturn_counts网上有很多介绍(可以参考这篇),这里不再赘述,本博文主要谈谈对dim参数的理解
实际上参数dim就是用来指定划分子张量的维度,unique就是先按指定维度将原张量X划分多个子张量x0,x1,...,xn,再在这些子张量中剔除重复的子张量

举例子:

X = torch.tensor([[3, 3, 3, 2, 4],
[0, 3, 0, 2, 4]])
# dim=0的情况
uni_X = torch.unique(X, dim=0)
print(uni_X)
# tensor([[0, 3, 0, 2, 4],
# [3, 3, 3, 2, 4]])
# dim=1的情况
uni_X = torch.unique(X, dim=1)
print(uni_X)
# tensor([[2, 3, 3, 4],
# [2, 0, 3, 4]])

以上例子,在dim=0情况时,按以下步骤理解:

  1. 先按第0维(即按行)把原张量X划分为2个子张量:x0=[3,3,3,2,4],x1=[0,3,0,2,4]
  2. 由于x0x1即没有重复的,所以结果还是由x0,x1组成
  3. 别忘了默认参数sorted=True,即对x0x1进行字典序升序后再返回,这里按字典序有x1<x0,所以返回张量[x1;x0]=[[0,3,0,2,4];[3,3,3,2,4]](按第0维划分的就按第0维拼回去)

在dim=1情况时,按以下步骤理解:

  1. 先按第1维(即按列)把原张量X划分为5个子张量:x0=[3,0],x1=[3,3],x2=[3,0],x3=[2,2],x4=[4,4]
  2. 可以发现只有x0=x2即只有这对重复,任意剔除其中一个(这里剔除了x2),则结果由x0,x1,x3,x4组成
  3. sorted=True影响,对x0,x1,x3,x4进行字典序升序后再返回,这里按字典序有x3<x0<x2<x4,所以返回张量[x3,x0,x2,x4]=[[2,3,3,4];[2,0,3,4]](按第1维划分的就按第1维拼回去)

最后,因为unique只能返回张量形式即要求结果对齐,所以要实现开头的需求,只能在指定维度内分别使用unique
即:

X = torch.tensor([[[3, 4, 3, 0],
[3, 0, 3, 3],
[4, 3, 0, 0]],
[[0, 2, 2, 3],
[4, 2, 1, 0],
[2, 4, 0, 3]]])
uni_X = [[torch.unique(x_ij) for x_ij in x_i] for x_i in X]
print(uni_X)
# [[tensor([0, 3, 4]), tensor([0, 3]), tensor([0, 3, 4])],
# [tensor([0, 2, 3]), tensor([0, 1, 2, 4]), tensor([0, 2, 3, 4])]]
posted @   kksk43  阅读(329)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
特效
黑夜
侧边栏隐藏
点击右上角即可分享
微信分享提示