pytorch04——自定义损失函数

pytroch在torch.nn模块中本来就为我们提供了许多常用的损失函数,比如MSELoss,L1Loss,BCELoss.........但是在科研中还有实际一些运用场景中,我们需要通过自定义损失函数的方式来实现一些损失函数。

1.以函数的方式自定义损失函数

def my_loss(output,input):
    loss = torch.mean((output - target) ** 2)
return loss

2.以类的方式进行定义

虽然以函数定义的方式很简单,但是以类方式定义更加常用,在以类的方式定义损失函数时,我们如果看每一个损失函数的继承关系,我们就可以发现Loss函数部分继承自_loss,部分继承自_weightedLoss,而_WeightedLoss继承自_loss._loss继承自nn.Module。因此我们可以将 以类定义的损失函数当作神经网络中的一层来对待,因此我们自定义的损失函数类就需要继承自nn.Module类
1.例如在分割领域常见的损失函数,DiceLoss

class DiceLoss(nn.Module):
    def __init__(self,weight=None,size_average=True):
        super(DiceLoss,self).__init__()
        
	def forward(self,inputs,targets,smooth=1)
        inputs = F.sigmoid(inputs)       
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()                   
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        return 1 - dice

# 使用方法    
criterion = DiceLoss()
loss = criterion(input,targets)

2.DiceBCELoss

class DiceBCELoss(nn.Module):
    def __init__(self,weight=None,size_average = True):
        super(DiceBCELoss, self).__init__()
    def forward(self,inputs,targets,smooth=1):
        inputs = F.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs*targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss
        return Dice_BCE

3.IouLoss

class IoULoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()
    def forward(self, inputs, targets, smooth=1):
        inputs = F.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection
        IoU = (intersection + smooth) / (union + smooth)
        return 1 - IoU

4.FocalLoss

class FocalLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(FocalLoss, self).__init__()
    def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
        inputs = F.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        BCE_EXP = torch.exp(-BCE)
        focal_loss = alpha * (1 - BCE_EXP) ** gamma * BCE
        return focal_loss

总结:

自定函数可以通过函数和类两种方式进行实现,不过在实际运用中用类更多,我们全程使用PyTorch提供的张量计算接口,这样集不需要我们去实现自动求导功能,并且可以直接进行调用cuda

posted @ 2022-03-18 08:52  TCcjx  阅读(1077)  评论(0编辑  收藏  举报