Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization

Sagawa S., Koh P. W., Hashimoto T. B. and Liang P. Distributionally robust neural networks for group shifts: on the importance of regularization for worst-case generalization. In International Conference on Learning Representations (ICLR), 2020.

作者希望通过 DRO (distributionally robust optimization)

(1)minθΘ{R(θ):=supQQE(x,y)Q[(θ;x,y)]}

来解决分布不均衡的问题: 即训练后模型通常对于训练中占据大部分的类别比较友好, 而在 atypical 的类别中表现很差.

主要内容

数据集

上面三个数据集的设计思路是一样的, 这里介绍一下 Waterbirds.
Waterbirds 是基于 CUB 得到的, 其中的鸟类图片均带有像素级的mask, 基于此可以将鸟和背景分离开来. 作者考虑两种类别的鸟:

  1. Waterbirds (水鸟): albatross, auklet, cormorant, frigatebird, fulmar, gull, jaeger, kittiwake, pelican, puffin, or tern, gadwall, grebe, mallard, merganser, guillemot, or Pacific loon;
  2. 其余的为 landbird (陆鸟).

数据集的group是这般构造的, 95%的水(陆)鸟的背景为水(陆地), 5%的水(陆)鸟的背景为陆地(水), 显然后者在自然界在也是较为稀少的存在. 可以通过此数据集研究所得模型的偏好. 需要注意的是, 验证集和测试集的比例是均衡的 (50% vs 50%), 这能更好地验证模型对于每个group的表现. 所以, 这里自然而然有一个分布偏移的问题.

训练方法

作者比较不同训练方法下的表现:

  • ERM:

(2)θ^ERM:=argminθΘE(x,y)P^[(θ;(x,y))],

其中P^为训练集上的经验分布.

  • group DRO: DRO假设 (1) 中的 Q:={g=1mqgP^g:qΔm}, 其中Δm是一个m1维度的单纯形, 时间上, 我们认为Q中包含的分布由m个部分组合而成. 易得:

(3)θ^DRO=argminθΘ{supQQE(x,y)Q[(θ;x,y)]}=argminθΘ{supgqg=1gqgE(x,y)P^g[(θ;x,y)]}=argminθΘ{maxgE(x,y)P^g[(θ;x,y)]}.

故, 实际上 θ^DRO就是使得各个group的最大化经验损失最小化.

  • group adjustments DRO: 当P^和真实的分布P一致的时候, 通过 group DRO 理论上就能缓解group的偏见问题, 但是往往存在分布偏移, 所以实际上理论和实际之间存在一个泛化误差: δg=E(x,y)Pg[(θ;(x,y))]E(x,y)P^g[(θ;(x,y))], 故作者引入一个估计 δ^g=C/ng来抵消这一误差:

(4)θ^adj=argminθΘmaxg{E(x,y)P^g[(θ;x,y)]+Cng}.

其中C代表模型的拟合能力 (超参数), 1/ng 则反应了小的group相较于大的group过拟合的一个倾向程度.

  • Importance Weighting: 重加权是平衡分布的一个常用手段,

(5)θ^w:=argminθΘE(x,y,g)P^[wg(θ;(x,y))],

通常的, 选择 wg=1/EgP^[I(g=g)].

ERM vs group DRO

如上图所示, 可知:

  1. ERM 和 DRO 的 Average Accuracy 是差不多的;
  2. 在不添加正则化的时候 (standard), ERM 和 DRO 在最差的group上的测试正确率非常糟糕, 虽然训练精度已经相当不错了, 这表现了严重的过拟合;
  3. 在添加了注入 2惩罚项和早停等正则化后, DRO才能在消除group偏见上起到作用.

group DRO vs adjusted group DRO

注: 仅2.

由上图可知, 传统的 group DRO 由于泛化误差的存在, 任有很大进步空间, 这一点可由 (5) 来缓解.

ERM vs DRO vs Importance weighting

注: UW (upweighting)

由上图可知, 重加权也能起到平衡的作用, 但是较劣于 DRO . 此外, 作者还证明了在关于θ是凸的连续函数时, 二者是等价的, 但是一旦没有了凸性就无法保证了.

Online group DRO

虽然已经有方法提出如何解决 (3) 了, 但是这些方法大抵缺乏可扩展性和收敛性保证. 本文便提出了一种 Online 算法. 注意求解 (3) 实际上等价于

minθΘsupgqg=1gqgE(x,y)P^g[(θ;x,y)],

作者将 qg视作可训练的参数, 然后交替训练其与θ.

注: 作者在实际中是使用mini-batch进行训练的:

In practice, we use minibatches and a momentum term for θ.

注: 看代码, 关于 q的更新也是 mini-batch的. 此时为对应group的平均损失.

注: 我好像在 boost 之类的算法中看到过类似这种的指数上升的更新方式, 但是我并不清楚它的利和弊. 简单看来, 这个更新方式会倾向于更大的group和更难的group, 感觉和reweighting的方式还是有挺大差别的.

代码

原文代码

posted @   馒头and花卷  阅读(1404)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
历史上的今天:
2019-05-06 A ROBUST KERNEL PCA ALGORITHM
点击右上角即可分享
微信分享提示