Differentiable TopK Operator
1. Differentiable TopK Operator
Form and derivation
Form: the forward computation is

$$p_i = \sigma(x_i + t), \qquad i = 1, \dots, n,$$

where $\sigma$ is the sigmoid function and the scalar $t$ (one per sample) is chosen so that

$$\sum_{i=1}^{n} \sigma(x_i + t) = k.$$

Note that every $p_i \in (0, 1)$ and the entries sum exactly to $k$, so $p$ is a smooth relaxation of the hard top-$k$ indicator vector; scaling the inputs $x$ up (i.e. lowering the temperature) pushes $p$ toward a 0/1 vector.
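For a quick numeric illustration (a sketch of my own, not from the original post; it assumes SciPy is available and the toy values are made up), one can solve the defining equation for $t$ with any one-dimensional root finder and inspect $p$:

import numpy as np
from scipy.optimize import brentq

def soft_topk(x, k):
    sig = lambda z: 1.0 / (1.0 + np.exp(-z))
    # sum_i sigmoid(x_i + t) is strictly increasing in t, so the root is unique.
    t = brentq(lambda t: sig(x + t).sum() - k, -x.max() - 20, -x.min() + 20)
    return sig(x + t)

x = np.array([3.0, 1.0, -2.0, 0.5])
print(soft_topk(x, 2), soft_topk(x, 2).sum())  # soft scores in (0, 1) that sum to 2.0
print(soft_topk(10 * x, 2))                    # scaling x pushes p toward the hard top-2 mask [1, 1, 0, 0]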
Gradient derivation:

Let $v_i = \sigma'(x_i + t)$ and $s = \sum_i v_i$. Then, since $t$ is itself a function of $x$,

$$\frac{\partial p_i}{\partial x_j} = \sigma'(x_i + t)\left(\delta_{ij} + \frac{\partial t}{\partial x_j}\right) = v_i \delta_{ij} + v_i \frac{\partial t}{\partial x_j}.$$

The difficulty lies in computing $\partial t / \partial x_j$. We exploit the defining condition $\sum_i \sigma(x_i + t) = k$: differentiating both sides with respect to $x_j$ gives

$$v_j + \Big(\sum_i v_i\Big) \frac{\partial t}{\partial x_j} = 0 \quad\Longrightarrow\quad \frac{\partial t}{\partial x_j} = -\frac{v_j}{s}.$$

Therefore, we obtain:

$$\frac{\partial p_i}{\partial x_j} = v_i \delta_{ij} - \frac{v_i v_j}{s}.$$
Vector version: if we let $v = \sigma'(x + t)$ (elementwise) and $s = \mathbf{1}^\top v$, the Jacobian can be written as

$$J = \frac{\partial p}{\partial x} = \operatorname{diag}(v) - \frac{v v^\top}{s},$$

so the vector-Jacobian product needed in the backward pass is $u^\top J = u \odot v - \dfrac{u^\top v}{s}\, v$ for an upstream gradient $u$.
Other details: how to compute $t$. The function $t \mapsto \sum_i \sigma(x_i + t)$ is strictly increasing, running from $0$ to $n$ as $t$ goes from $-\infty$ to $+\infty$, so for $0 < k < n$ there is a unique root, which can be found by binary search. The implementation below brackets it with a lower bound at which every sigmoid is essentially $0$ and an upper bound at which every sigmoid is essentially $1$, then runs a fixed number of bisection steps.
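As a sanity check of the formula above, here is a small numerical experiment (my own sketch, not part of the original post; sizes and tolerances are arbitrary). Because $t$ is re-solved by bisection at every perturbed input, a finite-difference Jacobian also picks up the implicit dependence of $t$ on $x$, and it should agree with $\operatorname{diag}(v) - v v^\top / s$:

import torch

def solve_t(x, k, iters=64):
    # Bisection for the scalar t with sigmoid(x + t).sum() == k (the sum is monotone in t).
    lo, hi = -x.max() - 10.0, -x.min() + 10.0
    for _ in range(iters):
        mid = (lo + hi) / 2
        if torch.sigmoid(x + mid).sum() < k:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2

def p_of(x, k):
    return torch.sigmoid(x + solve_t(x, k))

torch.manual_seed(0)
n, k = 6, 2
x = torch.randn(n, dtype=torch.double)

# Analytic Jacobian: diag(v) - v v^T / s with v = sigma'(x + t), s = sum(v).
p = p_of(x, k)
v = p * (1 - p)
J_analytic = torch.diag(v) - torch.outer(v, v) / v.sum()

# Finite-difference Jacobian; t is re-solved for every perturbed input,
# so the implicit dependence of t on x is included.
eps = 1e-6
J_fd = torch.zeros(n, n, dtype=torch.double)
for j in range(n):
    e = torch.zeros_like(x)
    e[j] = eps
    J_fd[:, j] = (p_of(x + e, k) - p_of(x - e, k)) / (2 * eps)

print(torch.allclose(J_analytic, J_fd, atol=1e-4))  # expected: True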
Implementation
# %% differentiable top-k function
import torch
from torch.func import vmap, grad
from torch.autograd import Function
import torch.nn as nn

sigmoid = torch.sigmoid
sigmoid_grad = vmap(vmap(grad(sigmoid)))


class TopK(Function):
    @staticmethod
    def forward(ctx, xs, k):
        ts, ps = _find_ts(xs, k)
        ctx.save_for_backward(xs, ts)
        return ps

    @staticmethod
    def backward(ctx, grad_output):
        # Compute vjp, that is grad_output.T @ J.
        xs, ts = ctx.saved_tensors
        # Let v = sigmoid'(x + t)
        v = sigmoid_grad(xs + ts)
        s = v.sum(dim=1, keepdims=True)
        # Jacobian is -vv.T/s + diag(v)
        uv = grad_output * v
        t1 = -uv.sum(dim=1, keepdims=True) * v / s
        return t1 + uv, None


@torch.no_grad()
def _find_ts(xs, k):
    # (batch_size, input_dim)
    _, n = xs.shape
    assert 0 < k < n
    # Lo should be small enough that all sigmoids are in the 0 area.
    # Similarly Hi is large enough that all are in their 1 area.
    # (batch_size, 1)
    lo = -xs.max(dim=1, keepdims=True).values - 10
    hi = -xs.min(dim=1, keepdims=True).values + 10
    for iteration in range(64):
        mid = (hi + lo) / 2
        subject = sigmoid(xs + mid).sum(dim=1)
        mask = subject < k
        lo[mask] = mid[mask]
        hi[~mask] = mid[~mask]
    ts = (lo + hi) / 2
    return ts, sigmoid(xs + ts)


def test_check():
    topk = TopK.apply
    xs = torch.randn(2, 10)
    ps = topk(xs, 2)
    print(f"{xs=}")
    print(f"{ps=}")
    print(f"{ps.sum(dim=1)=}")

    from torch.autograd import gradcheck

    input = torch.randn(20, 10, dtype=torch.double, requires_grad=True)
    for k in range(1, 10):
        print(k, gradcheck(topk, (input, k), eps=1e-6, atol=1e-4))


def sgd_update():
    topk = TopK.apply
    batch_size = 2
    k = 2
    tau = 10
    xs = torch.randn(batch_size, 10, dtype=torch.double, requires_grad=True)
    target = torch.zeros_like(xs)
    target[torch.arange(batch_size), torch.argsort(xs, descending=True)[:, :k].T] = 1.0
    print(f"{xs=}")
    print(f"{target=}")
    loss_fn = nn.MSELoss()
    learning_rate = 1

    def fn(x):
        x = x * tau
        return topk(x, k)

    for iteration in range(1, 1000 + 1):
        ws = nn.Parameter(data=xs, requires_grad=True)
        ps = fn(ws)
        loss = loss_fn(ps.view(-1), target.view(-1))
        loss.backward()
        xs = ws - learning_rate * ws.grad
        if iteration % 100 == 0:
            print(f"{iteration=} {fn(xs)=}")


sgd_update()
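As a follow-up usage sketch (my own, not from the original post; SoftTopKGate and its parameters are hypothetical), the operator can be dropped into a module as a differentiable feature-selection gate, with a temperature controlling how hard the selection is:

import torch
import torch.nn as nn

class SoftTopKGate(nn.Module):
    """Softly keeps the top-k features of its input (uses the TopK defined above)."""

    def __init__(self, dim, k, tau=10.0):
        super().__init__()
        self.scorer = nn.Linear(dim, dim)  # one score per feature
        self.k = k
        self.tau = tau  # temperature: larger tau -> gate closer to a 0/1 top-k mask

    def forward(self, x):
        scores = self.scorer(x) * self.tau
        gate = TopK.apply(scores, self.k)  # (batch, dim), each row sums to k
        return x * gate

gate = SoftTopKGate(dim=10, k=2)
out = gate(torch.randn(4, 10))
out.sum().backward()  # gradients flow back through the soft top-k selection
print(out.shape)      # torch.Size([4, 10])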
Related material
Differentiable top-k function - Stack Exchange
Softmax后传:寻找Top-K的光滑近似 - 科学空间