终于知道centernet的sigmoid为什么需要加clamp了
终于知道centernet的sigmoid为什么需要加clamp了, 由于我工程训练报错debug才知道,在计算focalloss的时候,
model_out = model_out.sigmoid()
pos_loss = -torch.log(model_out) * torch.pow(1 - model_out, self.alpha) * pos_inds
这里当网络model_out的值很小,1e-22次方接近0的值,然后经过torch.log的时候就会nan,所以把model_out的sigmoid加限制一下就可以。
改成如下:
model_out = model_out.sigmoid()
model_out = torch.clamp(model_out, min=1e-4, max=1 - 1e-4)
def _sigmoid(x):
y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
好记性不如烂键盘---点滴、积累、进步!