ARC121F Logical Operations on Tree【DP】

给定一棵树,给每个点填 \(0\)\(1\),给每条边填 \(\text{AND}\)\(\text{OR}\),求有多少种填法满足存在一种缩边的顺序,使得每次把一条边的两个端点缩成一个点,权为原端点与边的运算值,最终点的权为 \(1\)。答案对 \(998244353\) 取模。

\(n \leq 10^5\)


不妨钦定 \(1\) 为根。考虑一个特殊的局部:如果在操作过程中,出现了某个叶子的值为 \(1\) 且连接它的边为 \(\text{OR}\),那么只需要将其放在最后操作,就一定能够得到 \(1\)。我们用有序对 \((1, \text{OR})\) 来描述这个叶子的状态。

接着我们考虑出现的叶子的其他状态:

  • \((1,\text{AND})\)\((0,\text{OR})\) :显然这个叶子对答案没有影响,因此原树合法的充要条件是上面的部分合法。

  • \((0,\text{AND})\):由于操作之后其父亲的权会变为 \(0\),我们不妨在这样的叶子出现时就立刻对其进行操作,这样能够减小它对答案的负面影响。因此原树合法的充要条件还是上面的部分合法。

也就是说,如果不存在某种操作方式使得能够出现 \((1, \text{OR})\) 这样的叶子,那么依次删掉叶子一定不劣。此时合法的充要条件是:根节点权为 \(1\),且其儿子形成的叶子的状态都是 \((1,\text{AND})\)\((0,\text{OR})\)。这是相对容易计数的。

考虑容斥,改为计算不合法的方案数。如果我们能求出不存在 \((1,\text{OR})\) 的方案数 \(f\) 和其中合法的方案数 \(g\),那么 \(f-g\) 就是不合法的方案数。考虑树形 DP,设 \(f_u\) 表示 \(u\) 的子树中不存在 \((1,\text{OR})\) 的方案数,\(g_u\) 表示 \(u\) 的子树中不存在 \((1,\text{OR})\) 且合法的方案数。根据上面的分析,初始时 \(f_u = 2,g_u = 1\)

对于 \(f\),有转移 \(f_u \gets f_u \times \prod \limits_{v \in \text{son}_u} (2 \times f_v - g_v )\)。其中 \(2 \times f_v\) 是因为 \((u,v)\) 可以是 \(\text{OR}\)\(\text{AND}\),减去 \(g_v\) 是为了排除 \(v\) 子树合法且 \((u,v)\)\(\text{OR}\),即出现 \((1,\text{OR})\) 的情况。

对于 \(g\),有转移 \(g_u \gets g_u \times \prod \limits_{v \in \text{son}_u} f_v\)。注意无论 \(v\) 子树是否合法,由于其成为叶子时状态只能是 \((1, \text{AND})\)\((0,\text{OR})\),因此 \((u,v)\) 实际上是确定的,因此 \(f_v\) 的系数为 \(1\)

时间复杂度 \(O(n)\)

code
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5, mod = 998244353;
int n, f[N], g[N];
vector <int> e[N];
int ksm(int a, int b) {
    int ret = 1;
    while (b) {
        if (b & 1) ret = 1LL * ret * a % mod;
        a = 1LL * a * a % mod, b >>= 1;
    }
    return ret;
}
void dfs(int u, int ff) {
    f[u] = 2, g[u] = 1;
    for (auto v : e[u]) {
        if (v == ff) {
            continue;
        }
        dfs(v, u);
        f[u] = 1LL * f[u] * (1LL * f[v] * 2 % mod + mod - g[v]) % mod;
        g[u] = 1LL * g[u] * f[v] % mod;
    }
}
int main() {
    ios :: sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    for (int i = 1, x, y; i < n; i++) {
        cin >> x >> y;
        e[x].push_back(y);
        e[y].push_back(x);
    }
    dfs(1, 0);
    int ans = (1LL * ksm(2, 2 * n - 1) + mod - f[1] + g[1]) % mod;
    cout << ans << "\n";
    return 0;
}
posted @ 2023-02-22 21:58  came11ia  阅读(19)  评论(0编辑  收藏  举报