论文阅读-Causality Inspired Framework for Model Interpretation

标题:A Causality Inspired Framework for Model Interpretation
关键词:自然语言处理,因果推理,可解释机器学习
论文链接https://dl.acm.org/doi/pdf/10.1145/3580305.3599240
会议KDD

1. 简介

解释(explanation) 能否揭示 模型表现的根本原因(root cause)是XAI的重要问题。
文章提出了Causality Inspired Model Interpreter (CIMI),一种基于因果推理的解释器。

  • 在XAI中 变量集合(a set of variables)可以作为模型预测的可能原因(possible causes),如果其满足因果充分性假设(causal sufficiency assumption)。
  • 核心问题:如何有效地 发现突出的共同原因(prominent common causes),并能从大量特征和数据点推广到不同实例(泛化性generalizability)。

2. 因果图

img

  • X:输入实例
  • E:模型预测的泛化因果解释(generalized causal explanation for predeiction)
  • U:非解释部分(non-explanation)
  • \(\hat{Y}\):预测结果
  • \(M\):掩码,和X作运算之后能得到E
  • \(g(.)\):解释器Interpreter.

\(g(X) = E\)
\(E = M \odot X\)
\(U = (1-M) \odot X\)
\(\odot\)是逐元素乘法,\(M_i\in[0,1]\) 表示特征对输出的贡献。

3. 框架结构

img
解释器 Interpreter \(g(.)\)由 encoder \(f_e(.)\) 和 decoder \(\phi(.)\)组成
\(f_e(.)\)是预训练模型Bert的encoder
\(\phi(.)\)是唯一可以训练的模型,由1层LSTM和2层MLP组成。

  • 1-layer LSTM: hidden size is 64
  • 2-layers MLP: 64 × 16, 16 × 2

\(g(x) = \phi([f_e(x);v_x]_1)\)

  • \([f_e(x); v_x]_1\)表示Bert编码器 和 词嵌入\(v_x\)在axis 1上作连接操作。
  • \(d\): 嵌入(embedding)的维度
  • $\phi(.): $$R^{|x| \times 2d} \rightarrow [0,1]^{|x| \times 1}$
  • 输出第i维度的值:token i 被选中作为解释的概率

信息瓶颈理论infomation bottleneck theory 在前向传播过程中,神经网络会逐渐专注于输入中最重要的部分,过滤掉不重要的部分。(信息量是逐渐衰减的)
黑盒模型的编码器\(f_e(x)\)能够通过已经训练好的模型过滤掉一部分噪音。

4. 损失函数

1. causal sufficiency loss

\(x_e\)的部分对于预测\(f(x)\)的结果已经足够(sufficient)了,\(x_u\)对于预测没有帮助。
img

2. causal intervention loss

对非解释部分进行线性插值,不会影响解释的生成。\(g(X) = E\)
\(x'\)是从\(X\)中随机抽样得到。
img
img

3. weakly supervised loss

为了防止生成平凡解(包含所有token的解),设置了此弱监督损失。

  • 最大化实例x中的 token 被包括在解释\(x_e\)中的概率(maximizing the probability that the token in instance 𝑥 is included in 𝑥𝑒)
  • 最小化不在x中的token(噪音)被预测为解释的概率(minimizing the probability that a token not in 𝑥 (noise) is predicted to be the explanation)

img

5. 实验

5.1 评价指标

1. causul sufficiency(3个)

  • Decison Flip-Fraction of Tokens (DFFOT):改变模型预测所需要修改重要token的最小部分(minimum fraction)(越小越好
    img
  • Comprehensiveness (COMP):移除important token,原预测类的概率变化量(越大越好
    img
  • Sufficiency (SUFF):只保留import token,原预测类的概率变化量(越小越好
    img

2. explanation generalizability(1个)

  • Average Sensitivity (AvgSen):当输入被扰动时,一个解释的平均敏感度。(在实验中对每个实例替换5个token,并计算top-10 important token的敏感度)。
    img

5.2 Faithfulness Comparison

img

  • 测量不同方法的3个指标
  • 计算CIMI方法是否有显著提升

img
不同解释长度下的COMP指标的比较

5.3 Generalizability Comparison

用AvgSen作为衡量泛化能力的指标
img
在4个数据集上CIMI的top-10 token中有8个能保持一致,说明该方法能够抓住具有不变的泛化特征(invariant generalizable features)。

5.4 消融实验

尝试移除不同的模块,观察模型表现的变化,进而证明模型设计理念的正确性。
img
img

5.5 采样效率(Sampling Efficiency)

用 各种基于扰动的方法 在 相同前向传播次数下 的表现来衡量 采样效率。
img

5.6 不同离散化策略的表现

img
原来的方法采用的是可微的Softmax,本实验尝试将其替换为 Gumbel-Softmax 和 Deep Hash Learning 两个离散函数。
离散化的掩码能够提高解释的泛化能力,但是以模型表现下降为代价。

5.7 Usefulness Evaluation 模型排错

img
利用模型生成的解释剔除shortcut features。在CIMI debugging 之后模型表现是最好的。

posted @ 2024-04-09 19:41  Frank23  阅读(99)  评论(0编辑  收藏  举报