【论文阅读笔记】Distiling Causal Effect of Data in Class-Incremental Learning
Distiling Causal Effect of Data in Class-Incremental Learning
1. Contribution
这是一篇从因果角度思考持续学习的文章,这个思路比较新颖有意思
- 从因果角度解释了产生灾难性遗忘的原因,同时分析了 Data Replay 和 Distillation 两种持续学习方法能够在一定程度缓解灾难性遗忘的原因
- 从因果角度的视角,提出了一种等效于 Data Replay ,但不需要存储旧类样本的 Distillation 方法
- 针对新旧类的样本不均衡问题,沿着《Longtailed classification by keeping the good and removing the bad momentum causal effect》的工作,做了改进。
2. Motivation
本文中重点关注的持续方法主要是 Replay-based Method 和 Distillation-based Method。两种方法的有效缓解灾难性遗忘,但同时缺点也很明显:
- Replay-based Method 这种 end-to-end 的方法相比其他 output-end 的方法效果更好,但需要额外的存储空间;
- Distillation-based Method 不需要额外存储空间,但极度依赖于新旧类的分布情况,如果新旧类特征差异较大,Distillation loss 可能会为了尽可能保存旧类特征的同时,误导新类学习到的特征。
因此,作者提出了一个问题:“是否有一种与样本回放等效的端到端蒸馏方法?”
使用因果模型对灾难性遗忘和上述两种缓解灾难性遗忘的方法建模后,问题可以转化为“除了样本回放的方法外,是否还有其他方式,施加旧数据的影响?”
besides replay, is there another way to introduce the causal effect of the old data?
作者找到了这样的方法,也就是本文中提出的方法 Distilling Colliding Effect(DCE),实验证明这个方法可以有效提升 LUCIR 和 PODNet 的性能。
此外,作者还发现了新酒类的不均衡问题,并提出了 Incremental Momentum Effect Removal method,以去除 biased data causal effect 。
3. (Anti-) Forgetting in Causal Views
为了能够系统的解释灾难性遗忘以及缓解灾难性遗忘的因果关系,本文对持续学习中的数据、特征、预测标签使用因果图(Causal Graphs)来表示各自之间的关系。
在上面的这张图中,D 表示旧数据;I 表示用于训练的新数据;X 表示使用新模型提取的特征;\(X_{0}\) 表示使用旧模型提取得到的特征;Y 表示新模型的预测标签;\(Y_{0}\) 表示旧模型的预测标签。
通路 \(I \rightarrow X \rightarrow Y\):表示新数据使用新模型提取特征 X 后,经过新模型的分类器得到预测标签 Y。
通路 \((D,I) \rightarrow X_{0} \& (D,X_{0}) \rightarrow Y_{0}\):表示新旧数据 D I 使用旧模型提取得到特征后,经过旧模型的得到预测标签 \(Y_{0}\) 。(这里中间的 \(X_{0}\) 表示新数据 I 使用旧模型提取特征后得到的特征向量,\((D,X_{0})\) 表示旧数据 D 使用旧模型提取特征后得到的特征向量)。
通路 \(D \rightarrow I\) :为样本回放方法所添加的通路,通过回放存储的样本,使得旧类数据能够与新类数据建立联系。
通路 \(X_{0} \rightarrow X \& Y_{0}\rightarrow Y\):为在 logits 上做蒸馏的方法所添加的通路。
不连通通路 \(X_{0}\nrightarrow X\) : 为作者强调应该忽视的,作者认为虽然新旧模型间存在参数的继承,但随着模型训练,新模型从旧模型中继承得到的参数数量会呈指数下降,因此可以忽略。(参考《Overcoming catastrophic forgetting in neural networks》)
文中除了使用因果图外,还结合了数据公式进行表达,其中一切的开端都在这条式子:
这条公式(1)表示的是旧数据 D 对新模型预测标签的影响。其中,\(do(\cdot)\) 表示使用对某个单元施加具体的影响,具体来说,比如 \(do(A=a)\) 为对单元 A 赋予一个具体的影响a(可以理解为,施加动作a后,去除节点A的输入节点的影响);公式中第一行,表示有旧样本与无旧样本对最终预测标签 Y 产生的影响之差。通过留意因果图,我们发现 D 是一个外生变量,其没有父母节点,不可能受到其父母节点的影响,因此公式可以写成第二行的样子。
3.1 Forgetting
Figure 2 (a) 为对持续学习建模的因果图。产生灾难性遗忘的原因可以从因果的关系进行分析:
因为有向图中节点 D 到节点 Y 没有通路,其在 \(X_{0}\) 处被阻断了,所以式子可以写成第二行的样子,从而使得整个式子的值为 0 。这个结果说明在这种情况下,旧数据 D 对新模型的预测 Y 完全没有影响,从而造成了灾难性遗忘。
3.2 Anti-Forgetting
本文在这里分析了 数据回放 以及 知识蒸馏 之所以能够缓解灾难性遗忘的原因。
- Data Replay
数据回放的方法相当于在原来的因果图中,添加了一条 D 指向 I 的有向边,其数学公式可表示为
公式中第一行可以写成 \(P(Y|I,D=d)P(I|D=d)\) 是因为在下图所示的因果图中,D 和 I 都是外生节点,都有到达Y的通路,即共同对 Y 有影响,与同时 D 也能够对 I 产生影响。可以写成第二行,是因为 D 必须通过 I 才能够对 Y 产生影响,I 是唯一能够直接影响 Y 的 unit,所以有 \(P(Y|I,D)=P(Y|I)\) 。最后,显然式子不为0,即说明旧类数据能够对 新模型预测产生影响。数据回放的方法主要得益于 数据 I 与预测标签是一个端到端的关系。
- Feature & Label Distillation
特征蒸馏的数学表达如下,这种方法打破了 \(X_{0}\) 处的阻塞。
标签蒸馏的数学表达原文没有给出,个人写出的式子如下,这种方法打破了 \(Y_{0}\) 处的阻塞。
通过因果图的分析,我们发现现有的方法都相当于在图中增加了一条边,从而使得 D 能够与 Y 相连通。当然也有增加多条边的方法,如 iCaRL 是增加了 D→I 和 \(Y_{0}\)→Y 两条边,LUCIR 是增加了 D→I 、\(X_{0}\)→X 、\(Y_{0}\)→Y 三条边。
4. Distilling Causal Effect of Data
4.1 Distilling Colliding Effect (DCE)
通过上面因果图的分析,作者发现上述的这些方法都是避免了节点(Collider) \(X_{0}\) 或 \(Y_{0}\) 对 D 的阻塞,于是作者想通过控制节点 \(X_{0}\) ,为其设置条件,使其成为桥梁,从而建立 D 与 Y 之间的非直接联系。
作者称这种方式为 Colliding Effect(参考《Causal inference in statistics: A primer》)我们可以从生活中的一个学生努力与天赋对成绩影响的例子去理解,这个例子的因果图为 Intelligence→Grades←Efford,即学生的成绩由其努力的程度共同影响。通常来说,一个学生是否努力与其是否聪慧相互独立,但是如果我们知道一个学生能够取得一个好成绩(施加的condition),同时他是不聪明的,那我们就可以推断出他一定很努力。因此,两个独立的变量能够通过其共同的输出 collider variable (分数)产生联系。
能写出上面这个式子的原因是,在我们的设定下,\(X_{0}\) 是一个能够沟通 D 和 I 的桥梁,因此考虑总体的对 Y 的影响,可以写成 \(P(Y|I,X_{0})\) ,而紧跟着的那一项反映了 D 通过 \(X_{0}\) 对 I 产生。在第二行的公式中,作者将第二项看成对第一项(相当于新模型输出的预测概率)的加权项,由此得到本文方法的核心数学公式。
核心idea我们是理解了,但该如何理解这个权重呢?这个权重由 \(I,X_{0},D\) 3项组成,从公式上可以理解为,新输入的样本是
实际在计算 W 时,以一张图 i 为例,本模块首先会使用旧网络提取特征得到 \(\Omega_{0}(i)\) ,使用这个特征向量与其他新数据的旧特征向量计算距离,选取其中距离最小的 K 个样本,依照其距离的大小赋予相应的权重(距离越小,权重越大),最后归一化为所有 K 个权重之和为1。
这个过程的伪代码如下:
4.2 Incremental Momentum Effect Removal
作者认为,虽然我们前面已经能够保证旧数据 D 对新模型的预测标签 Y 的影响不为0,但是使用含有动量的 SGD 优化器更新参数的过程中,moving average更新的动量不可避免的会因为新旧类的不均衡问题而偏向新类。上图为作者对样本回放持续学习方法训练过程中出现的长尾分布的可视化(每一类保存10个旧样本)。
作者认为,不同于一般意义上的长尾分布问题,持续学习中的长尾分布问题有其独特的特点:
- 同一阶段内保存的旧类样本与新类样本不均衡
- 各个阶段之间,当前阶段的大类会称为下一阶段的小类
针对这个问题,作者设计了一种通过学习动量的方向,在推断时以此抵消动量影响的方法,其要点总结如下:
- 对于每一个增量阶段,通过 \(x_{m}=\mu \cdot x_{m-1} + x_{m}\),计算特征向量的 moving average
- 对于每一个增量阶段,对 moving average 更新的特征向量归一化得到当前 head direction \(h_{t}=x_{m}/||x_{M}||\)
- 对于每一个增量阶段,更新 dynamic head direction, \(h=(1-\beta)h_{t-1}+\beta h_{t}\)
- 在推断时,使用模型提取得到数据的特征向量后,用其减去新计算的 dynamic head direction,以消去 momentum 的影响,最后再由分类器分类
从因果角度的数学表达如下:
对应因果图为:
这里涉及两个参数 \(\alpha,\beta\) 都是可学习的参数
5. Experiments
本文方法的 DCE 中需要确定一个 K 最近邻的数字 K ,对 CIFAR100 和 ImageNet-Sub 使用了 K=10,对ImageNet-Full 使用了 K=1(以减少计算时间)。在 MER 中,因为本文方法没有存储的数据,所以作者抛弃了 finetune stage 并设置 \(\alpha=0.5, \beta=0.8\) 。
此外,下表中标识的 R 是每一类允许固定存放的样本数量,T 代表有 T 个增量阶段;
Forgetting 的值是越小越好
5.1 Comparisons with State-of-The-Art
5.2 Data Effect Distill vs. Data Replay
5.3 Different Replay Numbers
见表格1,R 代表每一类运行存储的类别样本数
5.4 Effectiveness of Each Causal Effect
5.5 Different Weight Assignments
这里的 Top-n 表示使用 n 个最近邻的样本计算 DCE;Rand 表示随机选取一个样本作为 DCE 中的样本;Bottom 代表选择距离最远的样本;Variant1 和 Variant2 都使用了前 10 个距离最小的样本,但 Variant1 会让样本自身的权重下降,以使得相邻样本的权重更大,Variant2 则依据余弦相似度+softmax分配权重。
5.6 Robustness of Incremental Momentum Effect Removal
这里的 N 代表增量阶段的个数
5.7 Different Size of the Initial Task
baseline 是 LUCIR