Masked Gradient-Based Causal Structure Learning

Ng I., Fang Z., Zhu S., Chen Z. and Wang J. Masked Gradient-Based Causal Structure Learning. arXiv preprint arXiv:1911.10500, 2019.

非线性, 自动地学习因果图.

主要内容

NOTEARS将有向无环图凝练成了易处理的条件, 本文将这种思想扩展至非线性的情况:

\[X_i = f_i(X_{\mathrm{pa}(i)}) + \epsilon_i, \]

其中\(X_i\)是因果图的结点, \(X_{\mathrm{pa}(i)}\)是其父结点, \(\epsilon\)是无关的噪声.

上述等式等价于

\[X_i = f_i(A_i \circ X) + \epsilon_i, \]

\(A_i\)是邻接矩阵\(A=[A_1|A_2|\cdots|A_d] \in \{0, 1\}^{d\times d}\)的第i列, \(A_{ij}=1\)表示结点\(X_i\)直接作用于\(X_j\).

所以本文的目标就可以转换为如何估计\(A\)(实际上有了\(A\)也就知道了因果图了). \(A\)应当满足的条件:

  1. \(A\) 能够表示有向无环图;
  2. \(X_i\)\(f(A_i \circ X)\)必须接近, 比如用常见的

\[\|X_i - f(A_i \circ X)\|_2^2 \]

来度量.

直接处理非常麻烦, 首先对上面的问题进行放松, 等价于

\[X_i = f_i(W_i \circ X) + \epsilon_i, \]

此时\(A = \mathcal{A}(W)\), 即

\[W_{ij} \not = 0 \rightarrow A_{ij} = 1; W_{ij} = 0 \rightarrow A_{ij} = 0. \]

本文更进一步, 令

\[W = g_{\tau}(U), \quad U \in \mathbb{R}^{d \times d}. \]

\[[g_{\tau}(U)]_{ij} = \sigma((u_{ij} + g) / \tau) = \frac{1}{1 + \exp(-(u_{ij}+ (g_1 - g_0)) / \tau)}, \]

其中

\[g = g_1 - g_0, \: g_i \mathop{\sim}\limits^{i.i.d.} \mathrm{Gumbel}(0, 1). \]

注: Gumbel.

此类操作能保证\(g_{\tau}(U) \in (0, 1)^{d\times d}\), 此时能够把\([g_{\tau}(U)]_{ij}\)看成是\(X_i\), \(X_j\)的关系的紧密型的度量, 在这种情况下

\[[g_{\tau}(U)]_{ij} \le \omega \Rightarrow A_{ij} = 0. \]

或许会问, 为什么不用sigmoid而用一个这么麻烦的东西, 原因是当\(\tau\)足够小的时候(如本文取的0.2), \([g_{\tau}(U)]_{ij}\)非常接近\(0\)或者\(1\), 而用sigmoid, 作者发现这些值都接近0, 不能很好的模拟有向无环图, 故采用了这个方案.

接下来, 只需要满足

\[\mathbb{E}[\mathrm{tr}(e^{g_{\tau}(U)}) - d] = 0, \]

即可保证\(g_{\tau}(U)\)能够代表有效无环图. 在实际中, 只需

\[\mathbb{E}[\mathrm{tr}(e^{g_{\tau}(U)}) - d] \le \xi. \]

注: 期望是关于\(g\)的.

最终的目标

总结下来,

\[\min_{U, \theta} \quad \mathbb{E}_g[\frac{1}{2n} \sum_{k=1}^n \mathcal{L}(x^{(k)}, f(g_{\tau}, x^{(k)}; \theta))] \\ \mathrm{s.t.} \quad \mathbb{E}_g[\mathrm{tr}(e^{g_{\tau}(U)}) - d] \le \xi. \]

注: \(\mathbb{E}\)是关于\(g\)的, \(n\)的观测数据的总数.

进一步地, 我们希望\(g_{\tau}\)是稀疏的, 故加上正则化项:

\[\min_{U, \theta} \quad \mathbb{E}_g[\frac{1}{2n} \sum_{k=1}^n \mathcal{L}(x^{(k)}, f(g_{\tau}, x^{(k)}; \theta)) + \lambda \|g_{\tau}(U)\|_1] \\ \mathrm{s.t.} \quad \mathbb{E}_g[\mathrm{tr}(e^{g_{\tau}(U)}) - d] \le \xi. \]

利用augmented Lagrange multiplier, 可得

\[L_p(U, \phi, \alpha) = \mathbb{E}_g[\frac{1}{2n} \sum_{k=1}^n \mathcal{L}(x^{(k)}, f(g_{\tau}, x^{(k)}; \theta)) + \lambda \|g_{\tau}(U)\|_1 + \alpha h(U)] + \frac{\rho}{2} (\mathbb{E}[h(U)])^2, \]

其中\(h(U):= \mathrm{tr}(e^{g_{\tau}(U)}) - d\).

采用分布更新:

\[U^{t+1}, \theta^{t+1} = \arg \min_{U, \theta} L_{\rho^t}(U, \phi, \alpha^t); \\ \alpha^{t+1} = \alpha^t + \rho^t \mathbb{E}[h(U^{t+1})]; \\ \rho^{t+1} = \left \{ \begin{array}{ll} \beta \rho^t, & \mathrm{if} \: \mathbb{E}[h(U^{t+1})] \ge \gamma \mathbb{E}[h(U^t)], \\ \rho^t, & \mathrm{otherwise}. \end{array} \right . \]

其中第一步使用Adam执行1000次迭代计算的.

文中还讨论了后处理的一些方法, 和\(A\)是否唯一.

代码

GES and PC

CAM

NOTEARS

DAG-GNN

GraN-DAG

posted @ 2021-05-29 20:42  馒头and花卷  阅读(224)  评论(2编辑  收藏  举报