论文阅读-Causality Inspired Framework for Model Interpretation

标题:A Causality Inspired Framework for Model Interpretation
关键词:自然语言处理,因果推理,可解释机器学习
论文链接https://dl.acm.org/doi/pdf/10.1145/3580305.3599240
会议KDD

1. 简介

解释(explanation) 能否揭示 模型表现的根本原因(root cause)是XAI的重要问题。
文章提出了Causality Inspired Model Interpreter (CIMI),一种基于因果推理的解释器。

  • 在XAI中 变量集合(a set of variables)可以作为模型预测的可能原因(possible causes),如果其满足因果充分性假设(causal sufficiency assumption)。
  • 核心问题:如何有效地 发现突出的共同原因(prominent common causes),并能从大量特征和数据点推广到不同实例(泛化性generalizability)。

2. 因果图

img

  • X:输入实例
  • E:模型预测的泛化因果解释(generalized causal explanation for predeiction)
  • U:非解释部分(non-explanation)
  • Y^:预测结果
  • M:掩码,和X作运算之后能得到E
  • g(.):解释器Interpreter.

g(X)=E
E=MX
U=(1M)X
是逐元素乘法,Mi[0,1] 表示特征对输出的贡献。

3. 框架结构

img
解释器 Interpreter g(.)由 encoder fe(.) 和 decoder ϕ(.)组成
fe(.)是预训练模型Bert的encoder
ϕ(.)是唯一可以训练的模型,由1层LSTM和2层MLP组成。

  • 1-layer LSTM: hidden size is 64
  • 2-layers MLP: 64 × 16, 16 × 2

g(x)=ϕ([fe(x);vx]1)

  • [fe(x);vx]1表示Bert编码器 和 词嵌入vx在axis 1上作连接操作。
  • d: 嵌入(embedding)的维度
  • ϕ(.)R|x|×2d[0,1]|x|×1
  • 输出第i维度的值:token i 被选中作为解释的概率

信息瓶颈理论infomation bottleneck theory 在前向传播过程中,神经网络会逐渐专注于输入中最重要的部分,过滤掉不重要的部分。(信息量是逐渐衰减的)
黑盒模型的编码器fe(x)能够通过已经训练好的模型过滤掉一部分噪音。

4. 损失函数

1. causal sufficiency loss

xe的部分对于预测f(x)的结果已经足够(sufficient)了,xu对于预测没有帮助。
img

2. causal intervention loss

对非解释部分进行线性插值,不会影响解释的生成。g(X)=E
x是从X中随机抽样得到。
img
img

3. weakly supervised loss

为了防止生成平凡解(包含所有token的解),设置了此弱监督损失。

  • 最大化实例x中的 token 被包括在解释xe中的概率(maximizing the probability that the token in instance 𝑥 is included in 𝑥𝑒)
  • 最小化不在x中的token(噪音)被预测为解释的概率(minimizing the probability that a token not in 𝑥 (noise) is predicted to be the explanation)

img

5. 实验

5.1 评价指标

1. causul sufficiency(3个)

  • Decison Flip-Fraction of Tokens (DFFOT):改变模型预测所需要修改重要token的最小部分(minimum fraction)(越小越好
    img
  • Comprehensiveness (COMP):移除important token,原预测类的概率变化量(越大越好
    img
  • Sufficiency (SUFF):只保留import token,原预测类的概率变化量(越小越好
    img

2. explanation generalizability(1个)

  • Average Sensitivity (AvgSen):当输入被扰动时,一个解释的平均敏感度。(在实验中对每个实例替换5个token,并计算top-10 important token的敏感度)。
    img

5.2 Faithfulness Comparison

img

  • 测量不同方法的3个指标
  • 计算CIMI方法是否有显著提升

img
不同解释长度下的COMP指标的比较

5.3 Generalizability Comparison

用AvgSen作为衡量泛化能力的指标
img
在4个数据集上CIMI的top-10 token中有8个能保持一致,说明该方法能够抓住具有不变的泛化特征(invariant generalizable features)。

5.4 消融实验

尝试移除不同的模块,观察模型表现的变化,进而证明模型设计理念的正确性。
img
img

5.5 采样效率(Sampling Efficiency)

用 各种基于扰动的方法 在 相同前向传播次数下 的表现来衡量 采样效率。
img

5.6 不同离散化策略的表现

img
原来的方法采用的是可微的Softmax,本实验尝试将其替换为 Gumbel-Softmax 和 Deep Hash Learning 两个离散函数。
离散化的掩码能够提高解释的泛化能力,但是以模型表现下降为代价。

5.7 Usefulness Evaluation 模型排错

img
利用模型生成的解释剔除shortcut features。在CIMI debugging 之后模型表现是最好的。

posted @   Frank23  阅读(144)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示