论文信息

论文标题:Joint domain alignment and discriminative feature learning for unsupervised deep domain adaptation
论文作者:Chao Chen , Zhihong Chen , Boyuan Jiang , Xinyu Jin
论文来源:AAAI 2019
论文地址:download 
论文代码:download
引用次数:175

1 Introduction

  近年来,大多数工作集中于减少不同领域之间的分布差异来学习共享的特征表示,由于所有的域对齐方法只能减少而不能消除域偏移,因此分布在簇边缘或远离相应类中心的目标域样本很容易被从源域学习到的超平面误分类。为缓解这一问题,提出联合域对齐和判别特征学习,有利于 域对齐 和 分类。具体提出了一种基于实例的判别特征学习方法和一种基于中心的判别特征学习方法,两者均保证了域不变特征具有更好的类内紧凑性和类间可分性。大量的实验表明,在共享特征空间中学习鉴别特征可以显著提高性能。
  域适应,关注如何从源域的大量标记样本和目标域有限或没有标记的目标样本学习分类,可以分为如下三种方法:

    • feature-based domain adaptation
    • instance-based domain adaptation
    • classifier-based domain adaptation

2 Method

  总体框架如下:

    

2.1 Problem statement

  In this work, following the settings of unsupervised domain adaptation, we define the labeled source data as  Ds={Xs,Ys}={(xis,yis)}i=1ns  and define the unlabeled target data as  Dt={Xt}={xit}i=1nt , where  xs  and  xt  have the same dimension  xs(t)Rd . Let  Θ  denotes the shared parameters to be learned.  HsRb×L  and  HtRb×L  denote the learned deep features in the bottleneck layer regard to the source stream and target stream, respectively.  b  indicates the batch size during the training stage and  L  is the number of hidden neurons in the bottleneck layer. Then, the networks can be trained by minimizing the following loss function.

    L(ΘXs,Ys,Xt)=Ls+λ1Lc+λ2Ld(1)Ls=1nsi=1nsc(Θxis,yis)(2)Lc=CORAL(Hs,Ht)(3)Ld=Jd(ΘXs,Ys)(4)

  其中

    • Ls 代表源域分类损失;
    • Lc=CORAL(Hs,Ht) 表示通过相关性对齐度量的域差异损失;
    • Jd(ΘXs,Ys) 代表鉴别损失,保证了域不变特征具有更好的类内紧致性和类间可分性;

2.2 Correlation Alignment (CORAL)

  为学习域不变特征,通过对齐源特征和目标特征的协方差来减少域差异。域差异损失如下:

    Lc=CORAL(Hs,Ht)=14L2Cov(Hs)Cov(Ht)F2(5)

  其中:

    • F2 为矩阵 Frobenius 范数;  
    • Cov(Hs)Cov(Ht) 表示 bottleneck layer 中源特征和目标特征的协方差矩阵;  
      • Cov(Hs)=HsJbHs
      • Cov(Ht)=HtJbHt
        • Jb=Ib1b1n1nTscentralized matrix
        • 1bRb1 列向量;
        • b 是批大小;

  注意,训练过程是通过小批量 SGD 实现的,因此,在每次迭代中,只有一批训练样本被对齐。

2.3 Discriminative Feature Learning

  为学习更具判别性的特征,提出两种判别特征学习方法:基于实例的判别特征学习基于中心的判别特征学习

  注意,整个训练阶段都是基于小批量 SGD 的。因此,下面给出的鉴别损失是基于一批样本的。

2.3.1 Instance-Based Discriminative Loss

  基于实例的判别特征学习的动机是:同一类的样本在特征空间中应该尽可能地接近,不同类的样本之间应有较大距离。

  基于实例的判别损失 LdI 可以表示为:

    JdI(his,hjs)={max(0,hishjs2m1)2Cij=1max(0,m2hishjs2)2Cij=0(6)
    LdI=i,j=1nsJdI(his,hjs)(7)

  其中:

    • Hs=[h1s;h2s;;hbs]
    • Cij=1 表示 his 和 hjs 来自同一个类,Cij=0 表示 hishjs 来自不同的类;
    • m2 大于 m1

  从 Eq.6Eq.7 中可以看出,判别损失会使类内样本之间的距离不超过 m1,而类间样本之间的距离至少 m2

  为简洁起见,将深度特征 Hs 的成对距离表示为 DHRb×b,其中 DijH=hishjs2。设 LRb×b 表示指示器矩阵,如果第 i 个样本和第 j 个样本来自同一个类,则表示 Lij=1,如果它们来自不同的类,则表示 Lij=0。然后,基于实例的判别损失可简化为:

    LdI=αmax(0,DHm1)2Lsum +max(0,m2DH)2(1L)sum(8)

2.3.2 Center-Based Discriminative Loss

  基于实例的鉴别损失 需要计算样本之间的成对距离,计算成本较高。受 Center Loss 惩罚每个样本到相应类中心的距离的启发,本文提出基于中心的判别特征学习:

    LdC=βi=1nsmax(0,hiscyi22m1)+i,j=1,ijcmax(0,m2cicj22)(9)

  其中:

    • β 为权衡参数;
    • m1m2 为两个约束边距 (m1<m2)
    • cyiRd 表示第 yi 类的质心,yi{1,2,,c}c 表示类数;  

  理想情况下,类中心 ci 应通过平均所有样本的深层特征来计算。但由于本文是基于小批量进行更新的,因此很难用整个训练集对深度特征进行平均。在此,本文做了一个必要的修改,对于 Eq.9 中判别损失的第二项,用于度量类间可分性的 cicj 是通过对当前一批深度特征进行平均来近似计算的,称之为 “批类中心” 。相反,用于测量类内紧致性的 cyi 应该更准确,也更接近 “全局类中心”。因此,在每次迭代中更新 cyi

    Δcj=i=1bδ(yi=j)(cjhis)1+i=1bδ(yi=j)(10)cjt+1=cjtγΔcjt(11)

  “全局类中心” 在第一次迭代中被初始化为“批类中心”,在每次迭代中通过 Eq.10Eq.11 进行更新,其中 γ 是更新“全局类中心”的学习速率。为简洁起见,Eq.9 可以简化为

    LdC=βmax(0,Hcm1)sum +max(0,m2Dc)Msum 

  其中:

    • Hc=[h1c;h2c;;hbc]hic=hiscyi22 表示第 i 个样本深层特征与其对应的中心 cyi 之间的距离;
    • DcRc×c 表示“批类中心”的成对距离,即 Dijc=cicj22

  不同于 Center Loss ,它只考虑类内的紧致性,本文不仅惩罚了深度特征与其相应的类中心之间的距离,而且在不同类别的中心之间加强了较大的边际。

2.4 Training

  所提出的 Instance-Based joint discriminative domain adaptation (JDDA-I)Center-Based joint discriminative domain adaptation (JDDA-C) 都可以通过小批量SGD轻松实现。对于 JDDA-I,总损失为  L=Ls+λ1Lc+λ2ILdILc 代表源域的分类损失。因此,参数 Θ 可以通过标准的反向传播直接更新

    Θt+1=Θtη(Ls+λ1Lc+λ2ILdI)xi(13) 

  由于 “global class center” 不能通过一批样本来计算,因此 JDDA-C 必须在每次迭代中同时更新 Θ 和“全局类中心”:
    Θt+1=Θtη(Ls+λ1Lc+λ2CLdC)xi
    cjt+1=cjtγΔcjtj=1,2,,c(14) 

3 Experiments

 

 

 

====

posted @   别关注我了,私信我吧  阅读(380)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!
历史上的今天:
2022-01-12 论文解读(DEC)《Unsupervised Deep Embedding for Clustering Analysis》
Live2D
点击右上角即可分享
微信分享提示