ruijiege

  博客园 :: 首页 :: 博问 :: 闪存 :: 新随笔 :: 联系 :: 订阅 订阅 :: 管理 ::
def smooth_l1_loss_modify(predict, target, mask, sigma=3):
    """Return the masked smooth-L1 loss, normalised by the number of objects.

    The per-element loss on ``d = predict[mask] - target[mask]`` is::

        0.5 * sigma**2 * d**2     if |d| <  1 / sigma**2
        |d| - 0.5 / sigma**2      otherwise

    which is ``torch.nn.SmoothL1Loss`` with ``beta = 1 / sigma**2``.

    Args:
        predict: prediction tensor (author's note: b x 2 x 96 x 128).
        target: ground-truth tensor with the same shape as ``predict``.
        mask: selection tensor with the same shape as ``predict``. It is used
            for boolean indexing, so it is presumably a ``bool`` tensor
            (TODO confirm with callers).
        sigma: sets the switch point ``1 / sigma**2`` between the quadratic
            and linear regions; larger values narrow the quadratic region.

    Returns:
        A 0-dim tensor: the summed loss divided by
        ``mask.sum() / mask.size(1)``, with that divisor treated as 1 when the
        mask selects nothing.
    """
    # Selected-element count divided by the size of dim 1. Presumably each
    # object marks every channel at its location — TODO confirm.
    num_object = mask.sum().item() / mask.size(1)
    # An empty mask would otherwise give 0 / 0 = NaN. Guard it the same way
    # l1_loss_modify and l2_loss_modify do.
    if num_object == 0:
        num_object = 1
    sigma2 = sigma * sigma
    diff = predict[mask] - target[mask]
    diff_abs = diff.abs()
    # torch.where selects one branch per element instead of summing both
    # branches weighted by 0/1. That way an inf in the unused branch cannot
    # turn into 0 * inf = NaN.
    loss = torch.where(
        diff_abs < 1 / sigma2,
        0.5 * sigma2 * torch.pow(diff, 2),
        diff_abs - 0.5 / sigma2,
    )
    return loss.sum() / num_object
    
def l2_loss_modify(predict, target, mask):
    """Return the masked sum of squared errors, normalised by object count.

    Args:
        predict: prediction tensor (author's note: b x 2 x 96 x 128).
        target: ground-truth tensor with the same shape as ``predict``.
        mask: selection tensor with the same shape as ``predict``, used for
            boolean indexing (presumably ``bool`` — TODO confirm).

    Returns:
        ``sum((predict[mask] - target[mask]) ** 2)`` divided by
        ``mask.sum() / mask.size(1)``, or by 1 when that count is zero.
    """
    channels = mask.size(1)
    # Fall back to 1 so an empty mask yields 0 rather than 0 / 0.
    num_object = (mask.sum().item() / channels) or 1
    residual = predict[mask] - target[mask]
    return residual.pow(2).sum() / num_object

def l1_loss_modify(predict, target, mask):
    """Return the masked sum of absolute errors, normalised by object count.

    Args:
        predict: prediction tensor (author's note: b x 2 x 96 x 128).
        target: ground-truth tensor with the same shape as ``predict``.
        mask: selection tensor with the same shape as ``predict``, used for
            boolean indexing (presumably ``bool`` — TODO confirm).

    Returns:
        ``sum(|predict[mask] - target[mask]|)`` divided by
        ``mask.sum() / mask.size(1)``, or by 1 when that count is zero.
    """
    selected_count = mask.sum().item()
    num_object = selected_count / mask.size(1)
    if not num_object:
        # Empty mask: divide by 1 so the loss is 0 instead of 0 / 0.
        num_object = 1
    return (predict[mask] - target[mask]).abs().sum() / num_object

 

posted on 2022-11-02 17:52  哦哟这个怎么搞  阅读(36)  评论(0)  编辑  收藏  举报