迁移学习(ADDA)《Adversarial Discriminative Domain Adaptation》

Note:[ wechat:Y466551 | 可加勿骚扰,付费咨询 ]

论文信息

论文标题:Adversarial Discriminative Domain Adaptation
论文作者:Eric Tzeng, Judy Hoffman, Kate Saenko, Trevor Darrell
论文来源:CVPR 2017
论文地址:download 
论文代码:download
引用次数:3257

1 介绍

  动机

    • 在做分类或者域偏移较大的任务,之前的方法比较不那么令人满意;
    • 先前的判别方法可以处理更大的域迁移,但对模型施加了绑定的权重,并且没有利用基于 GAN 的损失;

2 对抗域适应

  源域分类器训练:

    $\underset{M_{s}, C}{\text{min}} \quad \mathcal{L}_{\mathrm{cls}}\left(\mathbf{X}_{s}, Y_{t}\right)=  \mathbb{E}_{\left(\mathbf{x}_{s}, y_{s}\right) \sim\left(\mathbf{X}_{s}, Y_{t}\right)}-\sum\limits _{k=1}^{K} \mathbb{1}_{\left[k=y_{s}\right]} \log C\left(M_{s}\left(\mathbf{x}_{s}\right)\right)\quad\quad(1)$

  域鉴别器训练:

    $\begin{array}{l}\mathcal{L}_{\text {adv }_{D}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, M_{s}, M_{t}\right)= -\mathbb{E}_{\mathbf{x}_{s} \sim \mathbf{X}_{s}}\left[\log D\left(M_{s}\left(\mathbf{x}_{s}\right)\right)\right] -\mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log \left(1-D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right)\right]\end{array} \quad\quad(2)$

  域对抗技术的通用公式如下:

    $\begin{array}{l}\underset{D}{\text{min}}  & \mathcal{L}_{\mathrm{adv}_{D}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, M_{s}, M_{t}\right) \\\underset{M_{s}, M_{t}}{\text{min}}  & \mathcal{L}_{\mathrm{adv}_{M}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, D\right) \\\text { s.t. } & \psi\left(M_{s}, M_{t}\right)\end{array}\quad\quad(3)$

2.1 源域和目标域映射

  

  归结为三个问题:

    • 选择生成式模型还是判别式模型?
    • 针对源域与目标域的映射是否共享参数?
    • 损失函数如何定义?

2.2 Adversarial losses

  DANN 域鉴别器和特征提取器训练目标的关系:

    $\mathcal{L}_{\text {adv }_{M}}=-\mathcal{L}_{\mathrm{adv}_{D}}\quad\quad(6)$

  问题:在训练的早期,鉴别器快速收敛,导致梯度消失;

  GAN 训练目标:

    $\underset {G}{\text{min }}\underset {D}{\text{max }}V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]$

  拆解:

    $\mathcal{L}_{\mathrm{adv}_{M}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, D\right)=-\mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right] $

    $\begin{array}{l}\mathcal{L}_{\text {adv }_{D}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, M_{s}, M_{t}\right)= -\mathbb{E}_{\mathbf{x}_{s} \sim \mathbf{X}_{s}}\left[\log D\left(M_{s}\left(\mathbf{x}_{s}\right)\right)\right] -\mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log \left(1-D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right)\right]\end{array} $

  注意:训练鉴别器使用的是正常标签,训练生成器使用的是倒置标签

adversarial_loss = torch.nn.BCELoss()  # 损失函数(二分类交叉熵损失)
generator = Generator()           #生成器
discriminator = Discriminator()   #鉴别器

optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))  # 生成器优化器
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))   # 鉴别器优化器

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)  #torch.Size([64, 1])
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)   #torch.Size([64, 1])
        real_imgs = Variable(imgs.type(Tensor))     #torch.Size([64, 1, 28, 28])   真实数据

        # ----------------------> 训练生成器  [生成器使用噪声数据,使得其尽可能为真,迷惑鉴别器]
        optimizer_G.zero_grad()
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))    #torch.Size([64, 100])
        gen_imgs = generator(z)        #torch.Size([64, 1, 28, 28])
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # ----------------------> 训练鉴别器  [ 尽可能将真实数据和噪声数据区分开]
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
GAN code

  本文采用的方法类似于  GAN 。

3 对抗性域适应

  与之前方法不同: 

  

  本文方法:

  

  首先:预训练,使用源域训练一个分类器;[ 公式 9 第一个子公式]

  其次:对抗性训练

    • 使用源域和目标域数据,训练一个域鉴别器 Discriminator ,是的鉴别器尽可能区分源域和目标域数据 ;[ 公式 9 第二个子公式]  
    • 使用目标域数据,训练目标域特征提取器,尽可能使得域鉴别器区分不出目标域样本;[ 公式 9 第三个子公式]  

  最后:测试;

  ADDA 优化目标:

    $\begin{array}{l}\underset{M_{s}, C}{\text{min}} \quad \mathcal{L}_{\mathrm{cls}}\left(\mathbf{X}_{s}, Y_{s}\right) &=&-\mathbb{E}_{\left(\mathbf{x}_{s}, y_{s}\right) \sim\left(\mathbf{X}_{s}, Y_{s}\right)} \sum_{k=1}^{K} \mathbb{1}_{\left[k=y_{s}\right]} \log C\left(M_{s}\left(\mathbf{x}_{s}\right)\right) \\\underset{D}{\text{min}}  \quad\mathcal{L}_{\text {adv }_{D}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, M_{s}, M_{t}\right)&=& -\mathbb{E}_{\mathbf{x}_{s} \sim \mathbf{X}_{s}}\left[\log D\left(M_{s}\left(\mathbf{x}_{s}\right)\right)\right] \text { - } \mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log \left(1-D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right)\right] \\\underset{M_{t}}{\text{min}}  \quad \mathcal{L}_{\operatorname{adv}_{M}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, D\right)&=& -\mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right] \\\end{array} \quad\quad(9)$

 

for epoch in range(params.num_epochs):
    data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
    for step, ((images_src, _), (images_tgt, _)) in data_zip:
        # 2.1 train discriminator #
        images_src = make_variable(images_src)
        images_tgt = make_variable(images_tgt)
        optimizer_discriminator.zero_grad()
        feat_src = src_encoder(images_src) #torch.Size([50, 500])
        feat_tgt = tgt_encoder(images_tgt) #torch.Size([50, 500])
        feat_concat = torch.cat((feat_src, feat_tgt), 0)  #torch.Size([100, 500])
        pred_concat = discriminator(feat_concat.detach())  #torch.Size([100, 2])
        label_src = make_variable(torch.ones(feat_src.size(0)).long())
        label_tgt = make_variable(torch.zeros(feat_tgt.size(0)).long())
        label_concat = torch.cat((label_src, label_tgt), 0)  #torch.Size([100])
        loss_discriminator = criterion(pred_concat, label_concat)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # 2.2 train target encoder #
        optimizer_discriminator.zero_grad()
        optimizer_tgt.zero_grad()
        feat_tgt = tgt_encoder(images_tgt)
        pred_tgt = discriminator(feat_tgt)
        label_tgt = make_variable(torch.ones(feat_tgt.size(0)).long())
        loss_tgt = criterion(pred_tgt, label_tgt)
        loss_tgt.backward()
        optimizer_tgt.step()
ADDA Code
posted @ 2023-01-28 22:06  图神经网络  阅读(945)  评论(0编辑  收藏  举报
Live2D