终于知道centernet的sigmoid为什么需要加clamp了

终于知道centernet的sigmoid为什么需要加clamp了, 由于我工程训练报错debug才知道,在计算focalloss的时候,

model_out = model_out.sigmoid()
pos_loss = -torch.log(model_out) * torch.pow(1 - model_out, self.alpha) * pos_inds

这里当网络model_out的值很小,1e-22次方接近0的值,然后经过torch.log的时候就会nan,所以把model_out的sigmoid加限制一下就可以。

改成如下:

model_out = model_out.sigmoid()
model_out = torch.clamp(model_out, min=1e-4, max=1 - 1e-4)
def _sigmoid(x):
  y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
posted @ 2023-04-11 10:07  无左无右  阅读(40)  评论(0编辑  收藏  举报