CenterLoss
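Reference: https://blog.csdn.net/fxwfxw7037681/article/details/114440117
Center loss
Center loss focuses on intra-class distance. The assumption is that every object should lie as close as possible to the center of its own class. This is a basic assumption in clustering. Note that it says nothing about the distance between classes!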
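As a minimal sketch of what "close to its own class center" means, the loss can be written directly by indexing each sample's center. The function name `naive_center_loss` and the toy numbers are hypothetical. Wen et al. define the loss as \(\frac{1}{2}\sum_i \lVert x_i - c_{y_i} \rVert^2\). The sketch instead follows the convention of the code later in this post: no 1/2 factor, divided by the batch size.

```python
import torch


def naive_center_loss(x: torch.Tensor, labels: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference version: mean squared distance to the own-class center."""
    # Pick the center of each sample's own class: shape (batch_size, feat_dim)
    own_centers = centers[labels]
    # Squared Euclidean distance of every sample to its own class center
    sq_dist = (x - own_centers).pow(2).sum(dim=1)
    # Sum over the batch and divide by batch_size (no 1/2 factor, matching the class below)
    return sq_dist.sum() / x.size(0)


centers = torch.tensor([[0.0, 0.0], [3.0, 4.0]])  # 2 classes, feat_dim = 2
x = torch.tensor([[1.0, 0.0], [3.0, 2.0]])        # 2 samples
labels = torch.tensor([0, 1])
print(naive_center_loss(x, labels, centers))      # (1 + 4) / 2 = 2.5
```

Only the distance to each sample's own center enters the loss. Other centers are ignored, which is why the loss does not push classes apart by itself.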
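Let self.centers denote the centers of all classes, with shape (num_class, feat_dim). Here num_class is the number of classes and feat_dim is the dimension of each node's feature vector (its coordinates in feature space). Let the BiLSTM output be out.shape = (seq_len, batch, feat_dim), which we reshape to x.shape = (seq_len * batch, feat_dim). Following the process shown in the figure, dist_map is obtained as: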
# batch_size = x.size(0), i.e. seq_len * batch after the reshape
dist_map = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_class) + \
           torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_class, batch_size).t()
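The sketch below traces the reshape and this first stage end to end. All sizes and the LSTM configuration are illustrative assumptions, not taken from the original model. A bidirectional nn.LSTM with the default batch_first=False returns (seq_len, batch, 2 * hidden), so feat_dim = 2 * hidden here. The sketch also checks that the expand/transpose form matches plain broadcasting.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Illustrative sizes (assumption)
seq_len, batch, input_size, hidden, num_class = 5, 3, 8, 4, 10
lstm = nn.LSTM(input_size, hidden, bidirectional=True)   # batch_first=False by default
out, _ = lstm(torch.randn(seq_len, batch, input_size))   # (seq_len, batch, 2 * hidden)
feat_dim = out.size(-1)
x = out.reshape(-1, feat_dim)                            # (seq_len * batch, feat_dim)
centers = torch.randn(num_class, feat_dim)               # stand-in for self.centers

batch_size = x.size(0)                                   # seq_len * batch = 15
# Term 1: ||x_i||^2 as a column, repeated across the num_class columns
x_sq = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, num_class)
# Term 2: ||c_j||^2 expanded to (num_class, batch_size), then transposed
c_sq = torch.pow(centers, 2).sum(dim=1, keepdim=True).expand(num_class, batch_size).t()
dist_map = x_sq + c_sq                                   # dist_map[i, j] = ||x_i||^2 + ||c_j||^2

# The same map via plain broadcasting: (batch_size, 1) + (1, num_class)
dist_map_bc = x.pow(2).sum(1, keepdim=True) + centers.pow(2).sum(1).unsqueeze(0)
print(dist_map.shape, torch.allclose(dist_map, dist_map_bc))
```

Row t * batch + b of x is time step t of sequence b. Any per-node labels must therefore be flattened in the same (seq_len, batch) order, e.g. with tags.reshape(-1). The sum of the two expanded views is a fresh tensor, so the in-place update that follows is safe.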
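Why add two "distances" together to compute a distance? In fact, the true (squared) distance is only obtained after the following step:
dist_map.addmm_(x, self.centers.t(), beta=1, alpha=-2)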
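Before the derivation, here is a quick numerical check that this step really produces squared Euclidean distances. It compares against torch.cdist. The sizes and random data are arbitrary assumptions. Note that addmm_ (2-D matrices) is the right call here; addbmm_ expects 3-D batches and fails on 2-D inputs.

```python
import torch

torch.manual_seed(0)
batch_size, num_class, feat_dim = 6, 4, 5      # arbitrary sizes (assumption)
x = torch.randn(batch_size, feat_dim)
centers = torch.randn(num_class, feat_dim)

dist_map = (x.pow(2).sum(1, keepdim=True).expand(batch_size, num_class)
            + centers.pow(2).sum(1, keepdim=True).expand(num_class, batch_size).t())
# In place: dist_map = 1 * dist_map + (-2) * (x @ centers.T)
dist_map.addmm_(x, centers.t(), beta=1, alpha=-2)

reference = torch.cdist(x, centers).pow(2)     # squared Euclidean distances
print(torch.allclose(dist_map, reference, atol=1e-5))
```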
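So why does this represent the distance between two points?
Explanation:
In the figure above, c1, c2, ..., c8 denote the squared distances from the 8 class centers to the origin. b1, b2, b3, b4 denote the squared distances from the feature vector of each node in the sequence to the origin. The map on the far right of the figure is their sum. With these definitions, we can explain why the code above yields the distances between each sample and every class center: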
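For the 4-node, 8-class case of the figure, the map is written out below using the code's layout (rows are samples, columns are classes, matching the (batch_size, num_class) shape). Each entry is \(b_i + c_j\). At this stage the cross term between sample and center is still missing.

```latex
\text{dist\_map} =
\begin{pmatrix}
b_1 + c_1 & b_1 + c_2 & \cdots & b_1 + c_8 \\
b_2 + c_1 & b_2 + c_2 & \cdots & b_2 + c_8 \\
b_3 + c_1 & b_3 + c_2 & \cdots & b_3 + c_8 \\
b_4 + c_1 & b_4 + c_2 & \cdots & b_4 + c_8
\end{pmatrix},
\qquad b_i = \lVert x_i \rVert^2, \quad c_j = \lVert \mathrm{center}_j \rVert^2
```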
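- First, the call
dist_map.addmm_(x, self.centers.t(), beta=1, alpha=-2)
computes \[dist\_map = \beta \cdot dist\_map + \alpha \left( X \times self.centers^{T} \right)\] where \(\beta = 1\) and \(\alpha = -2\), that is: \[dist\_map = dist\_map - 2 \left( X \times self.centers^{T} \right)\]
- Second, for two points \(A\) and \(B\), the squared distance between them is \[\left\| A - B \right\|^{2} = \left\| A \right\|^{2} + \left\| B \right\|^{2} - 2 A \cdot B\] Take \(A\) as sample i and \(B\) as the center of class j. The first two terms are then exactly \(b_i\) and \(c_j\) from the figure, which dist_map already holds, and the first step subtracts the cross term \(2 A \cdot B\). Hence dist_map[i][j] ends up as the distance between sample i and class center j. Strictly speaking, it is the squared distance, not the distance itself.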
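Both facts used above can be checked numerically in a small sketch with arbitrary random data (an assumption for illustration): the expansion of a squared distance for one pair of vectors, and the semantics of addmm_ with beta=1, alpha=-2.

```python
import torch

torch.manual_seed(0)
a = torch.randn(5)
b = torch.randn(5)
lhs = (a - b).pow(2).sum()                                    # ||A - B||^2
rhs = a.pow(2).sum() + b.pow(2).sum() - 2 * torch.dot(a, b)   # ||A||^2 + ||B||^2 - 2 A.B
print(torch.allclose(lhs, rhs, atol=1e-5))

# addmm_ semantics: M <- beta * M + alpha * (X @ C^T)
M = torch.randn(3, 4)
X = torch.randn(3, 5)
C = torch.randn(4, 5)
expected = M - 2 * (X @ C.t())
result = M.clone().addmm_(X, C.t(), beta=1, alpha=-2)        # clone keeps M unchanged
print(torch.allclose(result, expected, atol=1e-5))
```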
Code
import torch
import torch.nn as nn


class CenterLoss(nn.Module):
    """
    Reference:
        Wen et al. A Discriminative Feature Learning Approach
        for Deep Face Recognition. ECCV 2016.
        https://blog.csdn.net/fxwfxw7037681/article/details/114440117
    Attributes:
        num_class: [int], number of classes;
        feat_dim: [int], dimension of the feature vector;
    """

    def __init__(self, num_class=10, feat_dim=2, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_class = num_class
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu
        # One learnable center per class, shape (num_class, feat_dim)
        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_class, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_class, self.feat_dim))
        nn.init.normal_(self.centers, mean=0, std=1)

    def forward(self, x, labels):
        """
        :param x: features, shape (batch_size, feat_dim)
        :param labels: ground-truth labels, shape (batch_size,)
        :return: scalar center loss
        """
        batch_size = x.size(0)
        # dist_map[i, j] = ||x_i||^2 + ||c_j||^2, shape (batch_size, num_class)
        dist_map = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_class) + \
                   torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_class, batch_size).t()
        # dist_map = dist_map - 2 * x @ centers^T, giving ||x_i - c_j||^2
        # (addmm_, not addbmm_: both operands are 2-D matrices)
        dist_map.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        # Class indices created on the same device as the features
        classes = torch.arange(self.num_class, device=x.device).long()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_class)
        # mask[i, j] is True only where j == labels[i]
        mask = labels.eq(classes.expand(batch_size, self.num_class))
        # Keep each sample's distance to its own center, then clamp for numerical stability
        dist = dist_map[mask].clamp(min=1e-12, max=1e+12)
        loss = dist.sum() / batch_size
        return loss
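The mask step in forward can be understood with a toy example. The labels and dist_map values below are hypothetical. Comparing labels against every class index produces a one-hot mask. Indexing with it keeps exactly one entry per row, namely the distance from sample i to the center of class labels[i]. That is the same result as a gather.

```python
import torch
import torch.nn.functional as F

batch_size, num_class = 4, 3
labels = torch.tensor([2, 0, 1, 2])                                # hypothetical labels
dist_map = torch.arange(batch_size * num_class, dtype=torch.float).reshape(batch_size, num_class)

classes = torch.arange(num_class)
mask = labels.unsqueeze(1).expand(batch_size, num_class).eq(classes.expand(batch_size, num_class))
# mask[i, j] is True only where j == labels[i]: a one-hot encoding of labels
assert torch.equal(mask, F.one_hot(labels, num_class).bool())

# Boolean indexing keeps one entry per row, in row order
picked = dist_map[mask]
gathered = dist_map.gather(1, labels.unsqueeze(1)).squeeze(1)
print(picked, gathered)                                            # both hold 2, 3, 7, 11
```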