类感知MMD

Example:

 

import torch

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Return the sum of several Gaussian kernel matrices over source and target.

    The rows of ``source`` and ``target`` are stacked (source first) and a
    kernel value ``exp(-squared_distance / bandwidth)`` is computed for every
    pair of rows, once per bandwidth. The per-bandwidth matrices are summed.

    Args:
        source: Source-domain samples, a 2-D tensor of shape (n, d).
        target: Target-domain samples, a 2-D tensor of shape (m, d).
        kernel_mul: Ratio between consecutive bandwidths.
        kernel_num: Number of Gaussian kernels that are summed.
        fix_sigma: Base bandwidth. When falsy (``None`` or 0) the base
            bandwidth is the sum of all pairwise squared distances divided
            by ``N**2 - N`` with ``N = n + m``. The argument is never
            modified, even when it is a tensor.

    Returns:
        Tensor of shape (n + m, n + m) holding the summed kernel values.
        Rows/columns ``[:n]`` belong to ``source`` and ``[n:]`` to ``target``.

    Note:
        If the base bandwidth evaluates to 0 (e.g. all rows are identical)
        the division below yields NaN entries.
    """
    total = torch.cat([source, target], dim=0)
    n_samples = int(total.size(0))

    # Broadcasting (1, N, d) against (N, 1, d) produces every pairwise
    # difference, so entry (i, j) below is the squared Euclidean distance
    # between row j and row i of `total` (0 on the diagonal).
    pairwise_diff = total.unsqueeze(0) - total.unsqueeze(1)
    l2_distance = (pairwise_diff ** 2).sum(2)

    if fix_sigma:
        base_bandwidth = fix_sigma
    else:
        # detach(): the bandwidth is treated as a constant, not differentiated.
        base_bandwidth = torch.sum(l2_distance.detach()) / (n_samples ** 2 - n_samples)

    # Out-of-place division so a tensor passed as `fix_sigma` is left untouched.
    # The bandwidths are base / mul**(num // 2) times mul**0 .. mul**(num - 1);
    # for the defaults that is base times [0.25, 0.5, 1, 2, 4].
    bandwidth = base_bandwidth / (kernel_mul ** (kernel_num // 2))
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]

    kernel_val = [torch.exp(-l2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    return sum(kernel_val)

def mmd_loss(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Compute the multi-kernel MMD between source and target samples.

    The joint kernel matrix from ``guassian_kernel`` is split into its four
    blocks (source/source, target/target and the two cross blocks) and the
    result is ``mean(XX) + mean(YY) - mean(XY) - mean(YX)``. Each block mean
    runs over every entry of the block, diagonal included.

    Args:
        source: Source-domain samples, a 2-D tensor of shape (n, d).
        target: Target-domain samples, a 2-D tensor of shape (m, d).
            ``m`` does not have to equal ``n``.
        kernel_mul: Ratio between consecutive kernel bandwidths.
        kernel_num: Number of Gaussian kernels that are summed.
        fix_sigma: Base bandwidth forwarded to ``guassian_kernel``.

    Returns:
        Scalar tensor with the MMD loss.
    """
    n_source = int(source.size(0))
    kernels = guassian_kernel(
        source, target, kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma
    )

    # Source rows come first in the stacked matrix, so n_source is the split point.
    XX = kernels[:n_source, :n_source]
    YY = kernels[n_source:, n_source:]
    XY = kernels[:n_source, n_source:]
    YX = kernels[n_source:, :n_source]

    # Averaging each block on its own works for n != m, where the blocks have
    # different shapes and cannot be added elementwise. For n == m it equals
    # the mean of (XX + YY - XY - YX).
    loss = XX.mean() + YY.mean() - XY.mean() - YX.mean()
    return loss


# Worked example: two source rows and two target rows, three features each.
source = torch.tensor([
    [1, 2, 3],
    [2, 3, 66],
])
target = torch.tensor([
    [1, 1, 2],
    [1, 5, 6],
])
result = mmd_loss(source, target)
print(result)  # the original author reports tensor(1.8873)


# Class-aware MMD terms. This fragment relies on names defined outside this
# file: sequence_output, labels, source_target_split, loss, alpha and self.mmd.

# Keep index 0 along dim 1 (presumably the first-token embedding -- confirm
# against the model). Author's example shape: torch.Size([8, 768]).
sequence_output = sequence_output[:, 0, :]


def _rows_with_label(features, row_labels, wanted):
    """Return the rows of ``features`` whose entry in ``row_labels`` equals ``wanted``."""
    return features[(row_labels == wanted).nonzero()[:, 0]]


# The first `source_target_split` rows are the source part, the rest the target part.
source_sequence_output = sequence_output[:source_target_split]  # e.g. torch.Size([4, 768])
target_sequence_output = sequence_output[source_target_split:]
source_labels = labels[:source_target_split]  # e.g. tensor([1, 0, 1, 1])
target_labels = labels[source_target_split:]

# Per-domain, per-class subsets (author's example: 3 positive rows, 1 negative row each).
source_pos_output = _rows_with_label(source_sequence_output, source_labels, 1)
source_neg_output = _rows_with_label(source_sequence_output, source_labels, 0)
target_pos_output = _rows_with_label(target_sequence_output, target_labels, 1)
target_neg_output = _rows_with_label(target_sequence_output, target_labels, 0)

# Per-class subsets over both domains (author's example: 2 negative rows, 6 positive rows).
neg_output = _rows_with_label(sequence_output, labels, 0)
pos_output = _rows_with_label(sequence_output, labels, 1)

# Same class across domains: add the MMD term (negatives first, then positives).
# A term is skipped when either side has no rows.
for source_rows, target_rows in (
    (source_neg_output, target_neg_output),
    (source_pos_output, target_pos_output),
):
    if len(source_rows) > 0 and len(target_rows) > 0:
        loss += alpha * self.mmd(source_rows, target_rows)

# Different classes: subtract the MMD term between positives and negatives.
if len(pos_output) > 0 and len(neg_output) > 0:
    loss -= alpha * self.mmd(pos_output, neg_output)

 

posted @ 2023-04-18 21:14  多发Paper哈  阅读(17)  评论(0)  编辑  收藏  举报
Live2D