# Click to view code
def forward(self, output_logits, target, extra_info=None):
    """Compute the multi-expert training loss with teacher/student self-distillation.

    If ``extra_info`` is None, this returns ``self.base_loss(output_logits, target)``.

    Otherwise the rows of the batch are split in half. The first half is treated
    as "view 1" (teacher) and the second half as "view 2" (student), following
    the original ``# view1`` / ``# view2`` comments.
    NOTE(review): presumably both halves hold the same samples in the same order
    under different augmentations — confirm against the data loader.

    Each of the three experts' logits is shifted by ``log(self.prior ** k + 1e-9)``
    with ``k`` taken from ``(-0.5, 1, 2.5)``. The returned loss is the sum of:

    * ``0.6 * min(epoch / self.warmup, 1.0)`` times a KL term. The KL term
      distills each expert's detached view-1 softmax into its view-2
      log-softmax, restricted to rows whose view-1 argmax equals the target.
    * ``self.base_loss`` of every expert's shifted logits over the full batch.

    Args:
        output_logits: Prediction passed to ``self.base_loss`` when
            ``extra_info`` is None. Unused otherwise.
        target: Class indices for the whole batch. When ``extra_info`` is
            given, the row count must be even.
        extra_info: Optional dict with keys ``'epoch'`` and ``'logits'``.
            ``'logits'`` is an indexable sequence whose entries 0, 1, 2 are the
            per-expert logit tensors.

    Returns:
        The accumulated loss: the weighted KL term plus the three
        ``self.base_loss`` values.

    Raises:
        ValueError: If the batch has an odd number of rows, so it cannot be
            split into two equal views.
    """
    if extra_info is None:
        # output_logits is the final (aggregated) prediction.
        return self.base_loss(output_logits, target)

    temperature = 1  # distillation temperature; 1 means a plain softmax
    kl_weight = 0.6
    # Exponent k of the per-expert logit shift log(prior ** k); the original
    # comments label the experts head / medium / few in this order.
    prior_exponents = (-0.5, 1, 2.5)

    epoch = extra_info['epoch']
    expert_logits = [
        extra_info['logits'][i] + torch.log(torch.pow(self.prior, k) + 1e-9)
        for i, k in enumerate(prior_exponents)
    ]

    batch_size = target.shape[0]
    if batch_size % 2:
        # With an odd batch the teacher/student halves differ in size and the
        # boolean masks below would not line up with the student rows.
        raise ValueError(
            f"forward expects an even batch of two stacked views, got {batch_size} rows"
        )
    half = batch_size // 2
    partial_target = target[:half]

    kl_loss = 0
    for logits in expert_logits:
        teacher_logits = logits[:half]  # view 1
        student_logits = logits[half:]  # view 2
        # Distill only on rows the teacher view already classifies correctly.
        # argmax of the logits equals argmax of their softmax, so no softmax is needed here.
        correct = teacher_logits.argmax(dim=1) == partial_target
        if correct.any():
            teacher_prob = F.softmax(teacher_logits[correct] / temperature, dim=1).detach()
            student_log_prob = F.log_softmax(student_logits[correct] / temperature, dim=1)
            kl_loss = kl_loss + F.kl_div(
                student_log_prob, teacher_prob, reduction='batchmean'
            ) * (temperature ** 2)

    # Linear warm-up of the distillation weight. A non-positive warmup means
    # "no warm-up" and avoids a division by zero.
    ramp = min(epoch / self.warmup, 1.0) if self.warmup > 0 else 1.0
    loss = kl_weight * kl_loss * ramp

    # Supervised loss of every expert's adjusted logits on the full batch.
    for logits in expert_logits:
        loss = loss + self.base_loss(logits, target)
    return loss