hello
def forward(self, output_logits, target, extra_info=None):
    if extra_info is None:
        return self.base_loss(output_logits, target)  # output_logits is the final (aggregated) prediction

    loss = 0
    temperature_mean = 1
    temperature = 1

    # Obtain logits from each expert, shifted by a power of the class prior
    epoch = extra_info['epoch']
    num = int(target.shape[0] / 2)
    expert1_logits = extra_info['logits'][0] + torch.log(torch.pow(self.prior, -0.5) + 1e-9)  # head
    expert2_logits = extra_info['logits'][1] + torch.log(torch.pow(self.prior, 1) + 1e-9)     # medium
    expert3_logits = extra_info['logits'][2] + torch.log(torch.pow(self.prior, 2.5) + 1e-9)   # few

    # Split each expert's logits into the two augmented views of the batch
    teacher_expert1_logits = expert1_logits[:num, :]  # view 1
    student_expert1_logits = expert1_logits[num:, :]  # view 2
    teacher_expert2_logits = expert2_logits[:num, :]  # view 1
    student_expert2_logits = expert2_logits[num:, :]  # view 2
    teacher_expert3_logits = expert3_logits[:num, :]  # view 1
    student_expert3_logits = expert3_logits[num:, :]  # view 2

    # Teacher distributions are detached; students use log-softmax for KL divergence
    teacher_expert1_softmax = F.softmax(teacher_expert1_logits / temperature, dim=1).detach()
    student_expert1_softmax = F.log_softmax(student_expert1_logits / temperature, dim=1)
    teacher_expert2_softmax = F.softmax(teacher_expert2_logits / temperature, dim=1).detach()
    student_expert2_softmax = F.log_softmax(student_expert2_logits / temperature, dim=1)
    teacher_expert3_softmax = F.softmax(teacher_expert3_logits / temperature, dim=1).detach()
    student_expert3_softmax = F.log_softmax(student_expert3_logits / temperature, dim=1)

    # Per-expert predicted class (and confidence) for each view
    teacher1_max, teacher1_index = torch.max(F.softmax(teacher_expert1_logits, dim=1).detach(), dim=1)
    student1_max, student1_index = torch.max(F.softmax(student_expert1_logits, dim=1).detach(), dim=1)
    teacher2_max, teacher2_index = torch.max(F.softmax(teacher_expert2_logits, dim=1).detach(), dim=1)
    student2_max, student2_index = torch.max(F.softmax(student_expert2_logits, dim=1).detach(), dim=1)
    teacher3_max, teacher3_index = torch.max(F.softmax(teacher_expert3_logits, dim=1).detach(), dim=1)
    student3_max, student3_index = torch.max(F.softmax(student_expert3_logits, dim=1).detach(), dim=1)

    # Distillation: only distill samples the teacher view classifies correctly
    partial_target = target[:num]
    kl_loss = 0
    if torch.sum(teacher1_index == partial_target) > 0:
        kl_loss = kl_loss + F.kl_div(student_expert1_softmax[teacher1_index == partial_target],
                                     teacher_expert1_softmax[teacher1_index == partial_target],
                                     reduction='batchmean') * (temperature ** 2)
    if torch.sum(teacher2_index == partial_target) > 0:
        kl_loss = kl_loss + F.kl_div(student_expert2_softmax[teacher2_index == partial_target],
                                     teacher_expert2_softmax[teacher2_index == partial_target],
                                     reduction='batchmean') * (temperature ** 2)
    if torch.sum(teacher3_index == partial_target) > 0:
        kl_loss = kl_loss + F.kl_div(student_expert3_softmax[teacher3_index == partial_target],
                                     teacher_expert3_softmax[teacher3_index == partial_target],
                                     reduction='batchmean') * (temperature ** 2)
    loss = loss + 0.6 * kl_loss * min(extra_info['epoch'] / self.warmup, 1.0)

    # Per-expert classification loss
    loss += self.base_loss(expert1_logits, target)  # expert 1
    loss += self.base_loss(expert2_logits, target)  # expert 2
    loss += self.base_loss(expert3_logits, target)  # expert 3
    return loss
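In short: each expert's logits are shifted by log(prior^λ) with λ = -0.5, 1, and 2.5, biasing the three experts toward head, medium, and few-shot classes respectively; the first half of the batch (view 1) acts as teacher and the second half (view 2) as student, KL distillation is applied only to samples the teacher predicts correctly, and its weight ramps up linearly over the first self.warmup epochs. Below is a minimal sketch of the module this forward() would live in, plus a sample call. The class name, the cls_num_list argument, and the CrossEntropyLoss base loss are assumptions for illustration, not taken from the post.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DiverseExpertDistillLoss(nn.Module):
        """Hypothetical wrapper holding the attributes that forward() above relies on."""
        def __init__(self, cls_num_list, warmup=20):
            super().__init__()
            prior = torch.tensor(cls_num_list, dtype=torch.float)
            self.register_buffer('prior', prior / prior.sum())  # empirical class prior
            self.base_loss = nn.CrossEntropyLoss()               # assumed per-expert base loss
            self.warmup = warmup                                 # distillation warm-up length in epochs

        # forward(self, output_logits, target, extra_info=None) is the method shown above

    # Example call: a batch of 8 samples = 4 originals (view 1) + their 4 augmented copies (view 2)
    criterion = DiverseExpertDistillLoss(cls_num_list=[100, 60, 30, 10, 5], warmup=20)
    num_classes, batch = 5, 8
    target = torch.randint(0, num_classes, (4,)).repeat(2)  # both views share the same labels
    extra_info = {
        'epoch': 5,
        'logits': [torch.randn(batch, num_classes) for _ in range(3)],  # one logits tensor per expert
    }
    # With the forward() from the post pasted into the class, the loss would be computed as:
    # loss = criterion(torch.randn(batch, num_classes), target, extra_info)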
Author: 太好了还有脑子可以用
Link: https://www.cnblogs.com/ZarkY/p/18343493
License: This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivatives 2.5 China Mainland License.