类感知MMD
Example:
import torch def guassian_kernel(source, target, kernel_mul = 2.0, kernel_num = 5, fix_sigma=None): ''' 将源域数据和目标域数据转化为核矩阵,即上文中的K Params: source: 源域数据(n * len(x)) target: 目标域数据(m * len(y)) kernel_mul: kernel_num: 取不同高斯核的数量 fix_sigma: 不同高斯核的sigma值 Return: sum(kernel_val): 多个核矩阵之和 ''' n_samples = int(source.size()[0]) + int(target.size()[0]) # 求矩阵的行数 total = torch.cat([source, target], dim=0) #将source和target按列方向合并 torch.Size([4, 3]) # 将total复制(n+m)份 total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) #torch.Size([4, 4, 3]) # 将total的每一行都复制成(n+m)行,即每个数据都扩展成(n+m)份 total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) # 求任意两个数据之间的和,得到的矩阵中坐标(i,j)代表total中第i行数据和第j行数据之间的l2 distance(i==j时为0) L2_distance = ((total0 - total1) ** 2).sum(2) #Size([4, 4]) # 调整高斯核函数的sigma值 if fix_sigma: bandwidth = fix_sigma else: bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples) #以fix_sigma为中值,以kernel_mul为倍数取kernel_num个bandwidth值(比如fix_sigma为1时,得到[0.25,0.5,1,2,4] bandwidth /= kernel_mul ** (kernel_num // 2) bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)] # 高斯核函数的数学表达式 kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] # 得到最终的核矩阵 return sum(kernel_val) def mmd_loss(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): ''' 计算源域数据和目标域数据的MMD距离 Params: source: 源域数据(n * len(x)) target: 目标域数据(m * len(y)) kernel_mul: kernel_num: 取不同高斯核的数量 fix_sigma: 不同高斯核的sigma值 Return: loss: MMD loss ''' batch_size = int(source.size()[0]) # 一般默认为源域和目标域的batchsize相同 kernels = guassian_kernel(source, target,kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) # 将核矩阵分为4份 XX = kernels[:batch_size, :batch_size] YY = kernels[batch_size:, batch_size:] XY = kernels[:batch_size, batch_size:] YX = kernels[batch_size:, :batch_size] loss = torch.mean(XX + YY - XY - YX) return loss source = torch.tensor([[1,2,3],[2,3,66]]) target = torch.tensor([[1,1,2],[1,5,6]]) result = mmd_loss(source,target) print(result) # tensor(1.8873) sequence_output = sequence_output[:, 0, :] #torch.Size([8, 768]) source_sequence_output = sequence_output[:source_target_split] # torch.Size([4, 768]) source_labels = labels[:source_target_split] # tensor([1, 0, 1, 1], device='cuda:2') source_pos_output = source_sequence_output[(source_labels == 1).nonzero()[:, 0]] # torch.Size([3, 768]) source_neg_output = source_sequence_output[(source_labels == 0).nonzero()[:, 0]] # torch.Size([1, 768]) target_sequence_output = sequence_output[source_target_split:] target_labels = labels[source_target_split:] target_pos_output = target_sequence_output[(target_labels == 1).nonzero()[:, 0]] # torch.Size([3, 768]) target_neg_output = target_sequence_output[(target_labels == 0).nonzero()[:, 0]] # torch.Size([1, 768]) neg_output = sequence_output[(labels == 0).nonzero()[:, 0]] # torch.Size([2, 768]) pos_output = sequence_output[(labels == 1).nonzero()[:, 0]] # torch.Size([6, 768]) if len(source_neg_output) > 0 and len(target_neg_output) > 0: loss += alpha * self.mmd(source_neg_output, target_neg_output) if len(source_pos_output) > 0 and len(target_pos_output) > 0: loss += alpha * self.mmd(source_pos_output, target_pos_output) if len(pos_output) > 0 and len(neg_output) > 0: loss -= alpha * self.mmd(pos_output, neg_output)
因上求缘,果上努力~~~~ 作者:多发Paper哈,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/articles/17331162.html