Differentiable TopK Operator

Formulation and Derivation

Formulation: the forward pass is defined as

$$\mathrm{TopK}(x, k) = \sigma(x + \Delta(x, k))$$

Note that $\Delta(\cdot)$ is chosen to satisfy the constraint $\sum_i \sigma(x_i + \Delta(x, k)) = k$, and $\sigma(x) = \frac{1}{1 + \exp\{-x\}}$ is the sigmoid function.


Gradient derivation: write the forward pass as

$$f(x, k) = \sigma(x + \Delta(x, k))$$

$$\frac{d f(x, k)_i}{d x_j} = \frac{d\,\sigma(x_i + \Delta(x, k))}{d x_j} = \sigma'(x_i + \Delta(x)) \left( \mathbb{I}_{i=j} + \frac{d \Delta(x)}{d x_j} \right)$$

The difficulty lies in computing $\frac{d \Delta(x)}{d x_j}$.

We compute this derivative by exploiting the constraint $\sum_i \sigma(x_i + \Delta(x)) = k$:

$$\frac{d k}{d x_j} = 0 = \sum_i \sigma'(x_i + \Delta(x)) \left( \mathbb{I}_{i=j} + \frac{d \Delta(x)}{d x_j} \right) = \sigma'(x_j + \Delta(x)) + \frac{d \Delta(x)}{d x_j} \sum_i \sigma'(x_i + \Delta(x))$$

Therefore, we obtain:

$$\frac{d \Delta(x)}{d x_j} = -\frac{\sigma'(x_j + \Delta(x))}{\sum_i \sigma'(x_i + \Delta(x))}$$

Vector form: if we let $v = \sigma'(x + \Delta(x))$ (applied elementwise), the Jacobian matrix is

$$J_{\mathrm{TopK}}(x) = \mathrm{diag}(v) - \frac{v v^{\top}}{\lVert v \rVert_1}$$
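A quick structural sanity check (an illustrative sketch, not part of the original derivation: the random positive vector `v` merely stands in for $\sigma'(x + \Delta(x))$): every column of this Jacobian sums to zero, which is exactly the statement $\frac{dk}{dx_j} = 0$ used above, and the matrix is symmetric.

```python
import torch

# Illustrative check of J = diag(v) - v v^T / ||v||_1.
# v stands in for sigmoid'(x + Delta(x)); any strictly positive vector
# suffices to check the structural properties of the Jacobian.
v = torch.rand(6) + 0.01
J = torch.diag(v) - torch.outer(v, v) / v.sum()

# Every column sums to zero: d(sum_i f_i)/dx_j = dk/dx_j = 0.
print(torch.allclose(J.sum(dim=0), torch.zeros(6), atol=1e-6))  # True
# The Jacobian is symmetric.
print(torch.allclose(J, J.T))  # True
```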

Other details: how do we actually find a $\Delta(x)$ satisfying $\sum_i \sigma(x_i + \Delta(x)) = k$? Since the left-hand side increases monotonically from $0$ to $n$ as $\Delta$ grows, a suitable value can be found quickly by bisection.
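For intuition, below is a minimal single-example sketch of that bisection (the values of `x` and `k` are arbitrary choices for illustration; the batched version actually used appears in the implementation that follows). The resulting scores sum to $k$, with the largest coordinates of $x$ receiving the largest scores.

```python
import torch

# Minimal single-example sketch: find Delta by bisection so that
# sigmoid(x + Delta).sum() == k, then inspect the soft top-k scores.
x = torch.tensor([3.0, -1.0, 0.5, 2.0, -2.0])
k = 2

lo, hi = -x.max() - 10, -x.min() + 10  # sum < k at lo, sum > k at hi
for _ in range(64):
    mid = (lo + hi) / 2
    if torch.sigmoid(x + mid).sum() < k:
        lo = mid  # Delta too small: the sigmoids sum to less than k
    else:
        hi = mid

delta = (lo + hi) / 2
p = torch.sigmoid(x + delta)
print(p)        # the two largest coordinates of x receive the largest scores
print(p.sum())  # ~= k
```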

Implementation

```python
# %% differentiable top-k function
import torch
from torch.func import vmap, grad
from torch.autograd import Function
import torch.nn as nn

sigmoid = torch.sigmoid
sigmoid_grad = vmap(vmap(grad(sigmoid)))


class TopK(Function):
    @staticmethod
    def forward(ctx, xs, k):
        ts, ps = _find_ts(xs, k)
        ctx.save_for_backward(xs, ts)
        return ps

    @staticmethod
    def backward(ctx, grad_output):
        # Compute the vjp, that is grad_output.T @ J.
        xs, ts = ctx.saved_tensors
        # Let v = sigmoid'(x + t)
        v = sigmoid_grad(xs + ts)
        s = v.sum(dim=1, keepdims=True)
        # Jacobian is -vv.T/s + diag(v)
        uv = grad_output * v
        t1 = -uv.sum(dim=1, keepdims=True) * v / s
        return t1 + uv, None


@torch.no_grad()
def _find_ts(xs, k):
    # xs: (batch_size, input_dim)
    _, n = xs.shape
    assert 0 < k < n
    # Lo should be small enough that all sigmoids are in the 0 area.
    # Similarly Hi is large enough that all are in their 1 area.
    # Both have shape (batch_size, 1).
    lo = -xs.max(dim=1, keepdims=True).values - 10
    hi = -xs.min(dim=1, keepdims=True).values + 10
    for iteration in range(64):
        mid = (hi + lo) / 2
        subject = sigmoid(xs + mid).sum(dim=1)
        mask = subject < k
        lo[mask] = mid[mask]
        hi[~mask] = mid[~mask]
    ts = (lo + hi) / 2
    return ts, sigmoid(xs + ts)


def test_check():
    topk = TopK.apply
    xs = torch.randn(2, 10)
    ps = topk(xs, 2)
    print(f"{xs=}")
    print(f"{ps=}")
    print(f"{ps.sum(dim=1)=}")

    from torch.autograd import gradcheck
    input = torch.randn(20, 10, dtype=torch.double, requires_grad=True)
    for k in range(1, 10):
        print(k, gradcheck(topk, (input, k), eps=1e-6, atol=1e-4))


def sgd_update():
    topk = TopK.apply
    batch_size = 2
    k = 2
    tau = 10
    xs = torch.randn(batch_size, 10, dtype=torch.double, requires_grad=True)
    target = torch.zeros_like(xs)
    target[torch.arange(batch_size), torch.argsort(xs, descending=True)[:, :k].T] = 1.0
    print(f"{xs=}")
    print(f"{target=}")
    loss_fn = nn.MSELoss()
    learning_rate = 1

    def fn(x):
        x = x * tau
        return topk(x, k)

    for iteration in range(1, 1000 + 1):
        ws = nn.Parameter(data=xs, requires_grad=True)
        ps = fn(ws)
        loss = loss_fn(ps.view(-1), target.view(-1))
        loss.backward()
        xs = ws - learning_rate * ws.grad
        if iteration % 100 == 0:
            print(f"{iteration=} {fn(xs)=}")


sgd_update()
```
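As a usage sketch (not from the original post: the module name `SoftTopKGate`, the temperature handling, and the toy data are illustrative assumptions), the operator can be wrapped as a differentiable gating layer that softly keeps the top-$k$ entries of a feature vector. It reuses the `TopK` class defined above, so it is meant to be run in the same session.

```python
# Hypothetical usage sketch built on the TopK Function defined above.
# SoftTopKGate and the temperature tau are illustrative choices, not part
# of the original post.
import torch
import torch.nn as nn


class SoftTopKGate(nn.Module):
    def __init__(self, k: int, tau: float = 10.0):
        super().__init__()
        self.k = k
        self.tau = tau  # larger tau -> gates closer to hard 0/1

    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        # scores: (batch_size, n); returns soft gates whose rows sum to k.
        return TopK.apply(scores * self.tau, self.k)


gate = SoftTopKGate(k=2)
scores = torch.randn(4, 8, requires_grad=True)
features = torch.randn(4, 8)
gated = gate(scores) * features  # softly keep the top-2 features per row
gated.sum().backward()           # gradients flow back into the scores
print(scores.grad.shape)         # torch.Size([4, 8])
```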

Related Resources

Differentiable top-k function - Stack Exchange
Softmax后传:寻找Top-K的光滑近似 ("Seeking a Smooth Approximation of Top-K") - 科学空间 (Scientific Spaces)
