McGan: Mean and Covariance Feature Matching GAN

Mroueh Y, Sercu T, Goel V, et al. McGan: Mean and Covariance Feature Matching GAN[J]. arXiv: Learning, 2017.

@article{mroueh2017mcgan:,
title={McGan: Mean and Covariance Feature Matching GAN},
author={Mroueh, Youssef and Sercu, Tom and Goel, Vaibhava},
journal={arXiv: Learning},
year={2017}}

利用均值和协方差构建IPM, 获得相应的mean GAN 和 covariance gan.

主要内容

IPM:

dF(P,Q)=supfF|ExPf(x)ExQf(x)|.

F是对称空间, 即fFfF,可得

dF(P,Q)=supfF{ExPf(x)ExQf(x)}.

Mean Matching IPM

Fv,w,p:={f(x)=v,Φw(x)|vRm,vp1,Φw:XRm,wΩ},

其中p表示p范数, Φw往往用网络来表示, 我们可通过截断w来使得Fv,w,p为有界线性函数空间(有界从而使得后面推导中sup成为max).

在这里插入图片描述
其中

μw(P)=ExP[Φw(x)]Rm.

最后一个等式的成立是因为:

x=max{v,x|v1},

p的对偶范数是q,1p+1q=1.

prime

整个GAN的训练过程即为

(3)mingθmaxwΩmaxv,vp1Lμ(v,w,θ),

其中

Lμ(v,w,θ)=v,ExPrΦw(x)Ezp(z)Φw(gθ(z)).

估计形式为
在这里插入图片描述

dual

也有对应的dual形态

(4)mingθmaxwΩμw(Pr)μw(Pθ)q.

在这里插入图片描述

Covariance Feature Matching IPM

FU,V,w:={f(x)=j=1kuj,Φw(x)vj,Φw(x),ui,uj=vi,vj=0,ij,else1},

等价于

FU,V,w:={f(x)=UTΦw(x),VTΦw(x),UTU=Ik,VTV=Ik,wΩ}.

并有
在这里插入图片描述

其中[A]k表示Ak阶近似, 如果A=iσiuiviT, σ1σ2,, 则[A]k=i=1kσiuiviT. Om,k:={MRm×k|MTM=Ik}, A=iσi表示算子范数.

prime

(6)mingθmaxwΩmaxU,VPm,kLσ(U,V,w,θ),

其中

Lσ(U,V,w,θ)=ExPrUTΦw(x),VTΦw(x)EzpzUTΦw(gθ(z)),VTΦw(gθ(z)).

采用下式估计

在这里插入图片描述

dual

(7)mingθmaxwΩ[Σw(Pr)Σw(Pθ)]k.

注: 既然Σw(Pr)Σw(Pθ)是对称的, 为什么UV? 因为虽然其对称, 但是并不(半)正定, 所以vi=ui也是有可能的.

算法

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

代码

未经测试.



import torch
import torch.nn as nn
from torch.nn.functional import relu
from collections.abc import Callable



def preset(**kwargs):
    def decorator(func):
        def wrapper(*args, **nkwargs):
            nkwargs.update(kwargs)
            return func(*args, **nkwargs)
        wrapper.__doc__ = func.__doc__
        wrapper.__name__ = func.__name__
        return wrapper
    return decorator


class Meanmatch(nn.Module):

    def __init__(self, p, dim, dual=False, prj='l2'):
        super(Meanmatch, self).__init__()
        self.norm = p
        self.dual = dual
        if dual:
            self.dualnorm = self.norm
        else:
            self.init_weights(dim)
            self.projection = self.proj(prj)


    @property
    def dualnorm(self):
        return self.__dualnorm

    @dualnorm.setter
    def dualnorm(self, norm):
        if norm == 'inf':
            norm = float('inf')
        elif not isinstance(norm, float):
            raise ValueError("Invalid norm")

        p = 1 / (1 - 1 / norm)
        self.__dualnorm = preset(p=p, dim=1)(torch.norm)


    def init_weights(self, dim):
        self.weights = nn.Parameter(torch.rand((1, dim)),
                                    requires_grad=True)

    @staticmethod
    def _proj1(x):
        u = x.max()
        if u <= 1.:
            return x
        l = 0.
        c = (u + l) / 2
        while (u - l) > 1e-4:
            r = relu(x - c).sum()
            if r > 1.:
                l = c
            else:
                u = c
            c = (u + l) / 2
        return relu(x - c)

    @staticmethod
    def _proj2(x):
        return x / torch.norm(x)

    @staticmethod
    def _proj3(x):
        return x / torch.max(x)

    def proj(self, prj):
        if prj == "l1":
            return self._proj1
        elif prj == "l2":
            return self._proj2
        elif prj == "linf":
            return self._proj3
        else:
            assert isinstance(prj, Callable), "Invalid prj"
            return prj



    def forward(self, real, fake):
        temp = (real - fake).mean(dim=1)
        if self.dual:
            return self.dualnorm(temp)
        elif not self.training and self.dual:
            raise TypeError("just for training...")
        else:
            self.weights.data = self.projection(self.weights.data) #some diff here!!!!!!!!!!
            return self.weights @ temp



class Covmatch(nn.Module):

    def __init__(self, dim, k):
        super(Covmatch, self).__init__()
        self.init_weights(dim, k)

    def init_weights(self, dim, k):
        temp1 = torch.rand((dim, k))
        temp2 = torch.rand((dim, k))
        self.U = nn.Parameter(temp1, requires_grad=True)
        self.V = nn.Parameter(temp2, requires_grad=True)

    def qr(self, w):
        q, r = torch.qr(w)
        sign = r.diag().sign()
        return q * sign

    def update_weights(self):
        self.U.data = self.qr(self.U.data)
        self.V.data = self.qr(self.V.data)

    def forward(self, real, fake):
        self.update_weights()
        temp1 = real @ self.U
        temp2 = real @ self.V
        temp3 = fake @ self.U
        temp4 = fake @ self.V
        part1 = torch.trace(temp1 @ temp2.t()).mean()
        part2 = torch.trace(temp3 @ temp4.t()).mean()
        return part1 - part2


posted @   馒头and花卷  阅读(594)  评论(0编辑  收藏  举报
编辑推荐:
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
阅读排行:
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
点击右上角即可分享
微信分享提示