论文信息
论文标题:Deep CORAL: Correlation Alignment for Deep Domain Adaptation
论文作者:Baochen Sun, Kate Saenko
论文来源:ECCV 2016
论文地址:download
论文代码:download
引用次数:2203
1 介绍
解决的问题:深度神经网络可以在大规模的标注数据中学校到特征,但是输入数据分布不同的时候泛化不是很好。因此提出了域适应来弥补性能。本文针对目标域没有标注数据情况,对 进行了改进。
方法用线性变换方法将源域和目标域分布的二阶统计特征进行对齐,对于无监督域适应效果很好。问题出在依赖的是线性变换,而且不是端到端训练。训练分为两步,首先提取特征,应用变换,然后训练 分类。
2 方法
模型框架:
设源域训练样本 ,标签 。无标签的目标域数据 。其中, 为网络 的输出维度。令 、 分别表示第 个源域、目标域样本的第 维特征。 表示特征协方差矩阵。
是源域和目标域特征的 协方差距离:
其中,
其中, 代表全 的列向量。
上述公式的梯度表达式:
损失函数:
三个版本的
def CORAL(source, target):
# source.shape = torch.Size([200, 31])
# target.shape = torch.Size([56, 31])
d = source.data.shape[1] #31
# source covariance 计算协方差矩阵
xm = torch.mean(source, 0, keepdim=True) - source #torch.Size([200, 31])
xc = xm.t() @ xm #torch.Size([31, 31]) #torch.Size([31, 31])
# target covariance
xmt = torch.mean(target, 0, keepdim=True) - target
xct = xmt.t() @ xmt #torch.Size([31, 31])
# frobenius norm between source and target
loss = torch.mean(torch.mul((xc - xct), (xc - xct)))
loss = loss/(4*d*d)
return loss
def CORAL1(self, source, target):
device = source.device
d = source.size(1)
ns, nt = source.size(0), target.size(0)
# source covariance
tmp_s = torch.ones((1, ns)).to(device) @ source
cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)
# target covariance
tmp_t = torch.ones((1, nt)).to(device) @ target
ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)
# frobenius norm
loss = (cs - ct).pow(2).sum().sqrt()
loss = loss / (4 * d * d)
return loss
def CORAL2(self, x, y):
# x.shape = torch.Size([32, 2048])
# y.shape = torch.Size([32, 2048])
mean_x = x.mean(0, keepdim=True) # torch.Size([1, 2048])
mean_y = y.mean(0, keepdim=True) # torch.Size([1, 2048])
cent_x = x - mean_x
cent_y = y - mean_y
cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)
mean_diff = (mean_x - mean_y).pow(2).mean()
cova_diff = (cova_x - cova_y).pow(2).mean()
return mean_diff + cova_diff
因上求缘,果上努力~~~~ 作者:别关注我了,私信我吧,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/17067133.html
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!