论文信息

论文标题:Deep CORAL: Correlation Alignment for Deep Domain Adaptation
论文作者:Baochen Sun, Kate Saenko
论文来源:ECCV 2016
论文地址:download 
论文代码:download
引用次数:2203

1 介绍

  解决的问题:深度神经网络可以在大规模的标注数据中学校到特征,但是输入数据分布不同的时候泛化不是很好。因此提出了域适应来弥补性能。本文针对目标域没有标注数据情况,对 CORAL 进行了改进。 

  CORAL 方法用线性变换方法将源域和目标域分布的二阶统计特征进行对齐,对于无监督域适应效果很好。问题出在依赖的是线性变换,而且不是端到端训练。训练分为两步,首先提取特征,应用变换,然后训练 SVM 分类。

2 方法

  模型框架:

  

  设源域训练样本 Ds={xi},xRd ,标签 Ls=yi,i{1,,L}。无标签的目标域数据  DT={ui},uRd  。其中,d  为网络  fc8  的输出维度。令  DSijDTij  分别表示第  i  个源域、目标域样本的第  j  维特征。CS(CT)  表示特征协方差矩阵。

  CORAL loss 是源域和目标域特征的 协方差距离:

    CORAL=14d2CSCTF2(1)

  其中,

    CS=1nS1(DSDS1nS(1DS)(1DS))(2)

    CT=1nT1(DTDT1nT(1DT)(1DT))(3)

  其中,1 代表全 1 的列向量。

  上述公式的梯度表达式:

    CORALDSij=1d2(nS1)((DS1nS(1DS)1)(CSCT))ij(4)

    CORALDTij=1d2(nT1)((DT1nT(1DT)1)(CSCT))ij(5)

  损失函数:

    =CLASS.+i=1tλiCORAL(6)

 

三个版本的 CORAL Loss

复制代码
def CORAL(source, target):
    # source.shape = torch.Size([200, 31])
    # target.shape = torch.Size([56, 31])
    d = source.data.shape[1]  #31

    # source covariance  计算协方差矩阵
    xm = torch.mean(source, 0, keepdim=True) - source  #torch.Size([200, 31])
    xc = xm.t() @ xm  #torch.Size([31, 31])            #torch.Size([31, 31])

    # target covariance
    xmt = torch.mean(target, 0, keepdim=True) - target
    xct = xmt.t() @ xmt      #torch.Size([31, 31])

    # frobenius norm between source and target
    loss = torch.mean(torch.mul((xc - xct), (xc - xct)))
    loss = loss/(4*d*d)

    return loss
复制代码
复制代码
    def CORAL1(self, source, target):
        device = source.device
        d = source.size(1)
        ns, nt = source.size(0), target.size(0)

        # source covariance
        tmp_s = torch.ones((1, ns)).to(device) @ source
        cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)

        # target covariance
        tmp_t = torch.ones((1, nt)).to(device) @ target
        ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)

        # frobenius norm
        loss = (cs - ct).pow(2).sum().sqrt()
        loss = loss / (4 * d * d)
        return loss
复制代码
复制代码
    def CORAL2(self, x, y):
        # x.shape = torch.Size([32, 2048])
        # y.shape = torch.Size([32, 2048])
        mean_x = x.mean(0, keepdim=True)  # torch.Size([1, 2048])
        mean_y = y.mean(0, keepdim=True)  # torch.Size([1, 2048])
        cent_x = x - mean_x
        cent_y = y - mean_y
        cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
        cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)

        mean_diff = (mean_x - mean_y).pow(2).mean()
        cova_diff = (cova_x - cova_y).pow(2).mean()

        return mean_diff + cova_diff
复制代码

 

 

posted @   别关注我了,私信我吧  阅读(744)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!
Live2D
点击右上角即可分享
微信分享提示