论文信息

论文标题:Domain-Adversarial Training of Neural Networks
论文作者:Yaroslav Ganin, Evgeniya Ustinova, Hana Ajakan, Pascal Germain
论文来源:JMLR 2016
论文地址:download 
论文代码:download
引用次数:5292

1 域适应

  We consider classification tasks where X is the input space and Y={0,1,,L1} is the set of L possible labels. Moreover, we have two different distributions over X×Y , called the source domain DS and the target domain DT . An unsupervised domain adaptation learning algorithm is then provided with a labeled source sample S drawn i.i.d. from DS , and an unlabeled target sample T drawn i.i.d. from DTX , where DTX is the marginal distribution of DT over X .

    S={(xi,yi)}i=1n(DS)n

    T={xi}i=n+1N(DTX)n

  with N=n+n being the total number of samples. The goal of the learning algorithm is to build a classifier η:XY with a low target risk

    RDT(η)=Pr(x,y)DT(η(x)y),

  while having no information about the labels of DT .

2 Domain Divergence

  假设:如果数据来自源域,域标签为 1,如果数据来自目标域,域标签为 0

  Definition 1. Given two domain distributions  DSX  and  DTX  over  X , and a hypothesis class  H , the  H-divergence between  DSX  and  DTX  is

    dH(DSX,DTX)=2supηH|PrxDSX[η(x)=1]PrxDTX[η(x)=1]|

  H-divergence 换言之:在假设空间  H  中,找到一个函数 h,使 PrxD[h(x)=1]  尽可能大,而 PrxD[h(x)=1] 尽可能小。

  可通过计算样本 S(DSX)nT(DTX)n 之间的经验  H-divergence  来近似:

    d^H(S,T)=2(1minηH[1ni=1nI[η(xi)=0]+1ni=n+1NI[η(xi)=1]])(1)

  其中,I[a] 是指示函数:若 a 为真时,I[a]=1,否则 I[a]=0

3 Proxy Distance

  由于经验 H-divergence 难以精确计算,可使用判别 源样本与目标样本 的学习算法完成近似。

  构造新的数据集 U

    U={(xi,0)}i=1n{(xi,1)}i=n+1N(2)

  使用 H-divergence 的近似表示 Proxy A-distance(PAD)

    d^A=2(12ϵ)(3)

  其中,ϵ 为 源域和目标域样本的分类泛化误差

4 Method

  假设输入空间由 m 维向量 X=Rm 构成,隐层 Gf:XRD ,由 (W,b)RD×m×RD 参数化:

    Gf(x;W,b)=sigm(Wx+b) with sigm(a)=[11+exp(ai)]i=1|a|(4)

  预测层 Gy:RD[0,1]L,由 (V,c)RL×D×RL 参数化:

    Gy(Gf(x);V,c)=softmax(VGf(x)+c) with softmax(a)=[exp(ai)j=1|a|exp(aj)]i=1|a|

  其中 L=|Y|

  给定一个源样本 (xi,yi),使用正确标签的负对数概率:

    Ly(Gy(Gf(xi)),yi)=log1Gy(Gf(x))yi

  对神经网络的训练会导致源域上的以下优化问题:

    minW,b,V,c[1ni=1nLyi(W,b,V,c)+λR(W,b)](5)

  其中,Lyi(W,b,V,c)=Ly(Gy(Gf(xi;W,b);V,c),yi)R(W,b) 是一个正则化项。

  域正则化器引出想法:借用 Definition 1H-divergence 推导出的域正则化器。

  源样本、目标样本分别表示为

    S(Gf)={Gf(x)xS}

    T(Gf)={Gf(x)xT}

  在 Eq.1 的基础上,给出样本 S(Gf)T(Gf) 之间的经验 H-divergence

    d^H(S(Gf),T(Gf))=2(1minηH[1ni=1nI[η(Gf(xi))=0]+1ni=n+1NI[η(Gf(xi))=1]])(6)

  域分类器 Gd:RD[0,1] ,由 (u,z)RD×R 参数化,计算了输入来自源域 DSX 或目标域 DTX 的概率:

    Gd(Gf(x);u,z)=sigm(uGf(x)+z)(7)

  因此,域分类器的交叉熵损失如下:

    Ld(Gd(Gf(xi)),di)=dilog1Gd(Gf(xi))+(1di)log11Gd(Gf(xi))

  其中,di 表示第 i 个样本的二分类域标签。

  Eq.5 的目标中添加域自适应项,并给出以下正则化器:

    R(W,b)=maxu,z[1ni=1nLdi(W,b,u,z)1ni=n+1NLdi(W,b,u,z)](8)

  其中,Ldi(W,b,u,z)=Ld(Gd(Gf(xi;W,b);u,z),di)

  R(W,b) 试图近似 Eq.6H-divergence,因为 2(1R(W,b))d^H(S(Gf),T(Gf)) 的一个替代品。

  Eq.5 的完整优化目标重写如下:

    E(W,V,b,c,u,z)=1ni=1nLyi(W,b,V,c)λ(1ni=1nLdi(W,b,u,z)+1ni=n+1NLdi(W,b,u,z))(9)

  对应的参数优化 W^, V^, b^, c^, u^, z^

    (W^,V^,b^,c^)=arg minW,V,b,cE(W,V,b,c,u^,z^)(u^,z^)=arg maxu,zE(W^,V^,b^,c^,u,z)

Generalization to Arbitrary Architectures

  分类损失和域分类损失:

    Lyi(θf,θy)=Ly(Gy(Gf(xi;θf);θy),yi)Ldi(θf,θd)=Ld(Gd(Gf(xi;θf);θd),di)

  优化目标:

    E(θf,θy,θd)=1ni=1nLyi(θf,θy)λ(1ni=1nLdi(θf,θd)+1ni=n+1NLdi(θf,θd))(10)

  对应的参数优化 θ^f, θ^y, θ^d

    (θ^f,θ^y)=argminθf,θyE(θf,θy,θ^d)(11)θ^d=argmaxθdE(θ^f,θ^y,θd)(12)

   如前所述,由 Eq.11-Eq.12 定义的鞍点可以作为以下梯度更新的平稳点找到:

    θfθfμ(LyiθfλLdiθf)(13)θyθyμLyiθy(14)θdθdμλLdiθd(15)

  整体框架:

  


   

 

5 总结

  问题:

    • 存在梯度消失的问题;
    • 训练过程不稳定;

 

复制代码
for epoch in range(n_epoch):
    len_dataloader = min(len(dataloader_source), len(dataloader_target))
    data_source_iter = iter(dataloader_source)
    data_target_iter = iter(dataloader_target)

    i = 0
    while i < len_dataloader:
        p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        # training model using source data
        data_source = data_source_iter.next()
        s_img, s_label = data_source

        class_output, domain_output = my_net(input_data=s_img, alpha=alpha)
        err_s_label = loss_class(class_output, class_label)
        err_s_domain = loss_domain(domain_output, domain_label)

        # training model using target data
        t_img, _ = data_target_iter.next()
        domain_label = torch.ones(batch_size)
        domain_label = domain_label.long()
        _, domain_output = my_net(input_data=t_img, alpha=alpha)
        err_t_domain = loss_domain(domain_output, domain_label)
        err = err_t_domain + err_s_domain + err_s_label

        err.backward()
        optimizer.step()
        i += 1

def forward(self, input_data, alpha):
    feature = self.feature(input_data)
    class_output = self.class_classifier(feature)

    reverse_feature = ReverseLayerF.apply(feature, alpha)
    domain_output = self.domain_classifier(reverse_feature)
    return class_output, domain_output
复制代码
posted @   别关注我了,私信我吧  阅读(873)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列1:轻松3步本地部署deepseek,普通电脑可用
· 按钮权限的设计及实现
· 【杂谈】分布式事务——高大上的无用知识?
Live2D
点击右上角即可分享
微信分享提示