[pytorch]pytorch 将x按阈值条件进行变换/裁剪/映射

pytorch 将x按阈值条件进行变换/裁剪/映射

0. 场景

在进行深度学习模型裁剪时,一个很显然的需求是将小于某个阈值的值全都设置为0,比如设置成阈值为0.5, 也即x[x<=0.5] = 0
借助mask,很容易实现上述的需求,但推广起来,比如将对应位置设置为对应的fun,也即x[x<=0.5] = f1(x), x[x<=1.0 and x>0.5] = f2(x) 设计起来可能就没那么容易了。

我们便遇到一个场景,要对不同的x实现不同的放缩,即x[x<=0.5] = w1 * x, x[x<=1.0 and x>0.5] = w2 * x, ...

1. 方法

首先需要将x的值进行映射,即将x映射到对应的id上,然后通过索引的方式找到对应的weight值,最后再相乘即可,下面是具体的代码实现:

def magic_func(x):
    k = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]
    k = torch.tensor(k)
    
    idx = torch.tensor(x+8, dtype=torch.long)
    idx = torch.clamp(idx, min=0, max=15)
    weight_k = k[idx]
    
    y = weight_k.mul(x)
    
    return y

通过这种方法,可以将x在[-8, -7]区间的置为1x,[-7, -6]的区间的置为2x。

posted @ 2022-05-15 15:09  wildkid1024  阅读(436)  评论(0编辑  收藏  举报