【题解】CF1616H Keep XOR Low
很好计数题,爱来自汐斯塔。
思路
01Trie 上 dp.
首先根据两两异或想到 01Trie,既然是计数自然考虑在 01Trie 上 dp.
先将 \(a\) 中的所有数插入 01Trie.
最直观的想法是按位 dp,也就是令 \(f[u]\) 表示 01Trie 上在 \(u\) 的子树内选取的合法方案数。
考虑 \(x\) 在 \(u\) 所代表的二进制位上的取值,发现如果是 \(1\) 的话下面的两棵子树之间会相互限制,所以很难转移。
考虑改改状态,设成 \(f[u1][u2]\) 表示在 01Trie 上的两棵子树 \(u1, u2\) 内选取的方案数。
注意这里的 \(u1, u2\) 是同一层的结点,也就是代表的二进制位相同。并且我们钦定 \(u1, u2\) 在更高的二进制位上 异或的结果小于 \(x\).
至于这是怎么想到的,我不知道,懂哥麻烦教教。
然后考虑根据 \(x\) 在 \(u1\) 子结点 代表的二进制位上的取值分类讨论:
先记 01Trie 中结点 \(u\) 的左子结点为 \(\operatorname{ls}(u)\),右子结点为 \(\operatorname{rs}(u)\),子树中存在数的个数为 \(\operatorname{sz}(u)\).
-
\(x\) 在这一位上是 \(0\).
那么 \(\operatorname{ls}(u1)\) 和 \(\operatorname{rs}(u2)\),\(\operatorname{ls}(u2)\) 和 \(\operatorname{rs}(u1)\) 的子树不能同时选取。
-
只在 \(u1\) 或 \(u2\) 的子树内选,因为钦定的条件可以随意选,方案数为 \((2^{\operatorname{sz}(\operatorname{ls}(u1))} - 1) \cdot (2^{\operatorname{sz}(\operatorname{rs}(u1))} - 1) + (2^{\operatorname{sz}(\operatorname{ls}(u2))} - 1) \cdot (2^{\operatorname{sz}(\operatorname{rs}(u2))} - 1)\).(排除为空的情况)
-
在 \(u1, u2\) 的同向子结点的子树内选,方案数为 \(f[\operatorname{ls}(u1)][\operatorname{ls}(u2)] + f[\operatorname{rs}(u1)][\operatorname{rs}(u2)]\).
-
-
\(x\) 在这一位上是 \(1\).
注意到 \(\operatorname{ls}(u1)\) 和 \(\operatorname{ls}(u2)\),以及 \(\operatorname{rs}(u1)\) 和 \(\operatorname{rs}(u2)\) 之间异或的值一定小于等于 \(x\)。
所以现在相互之间存在约束的只有 \(\operatorname{ls}(u1)\) 和 \(\operatorname{rs}(u2)\),以及 \(\operatorname{ls}(u2)\) 和 \(\operatorname{rs}(u1)\).
这里有一个相当巧妙的结论。我们发现钦定其中一组约束成立的时候,另一组约束相较于这组约束是独立的。例如当前选取的方案满足 \(\operatorname{ls}(u2)\) 和 \(\operatorname{rs}(u1)\) 的约束时(只在其中选取),已经选取的数不会影响到 \(\operatorname{ls}(u1)\) 和 \(\operatorname{rs}(u2)\) 的选取方式。
这意味着我们只需要分别求出满足这两种约束的方案总数,然后相乘就行。这里的意义实际上是在 \(\operatorname{ls}(u1), \operatorname{rs}(u1), \operatorname{ls}(u2), \operatorname{rs}(u2)\) 四棵子树内在满足两组约束的前提下选数。
方案总数为 \(f[\operatorname{ls}(u1)][\operatorname{rs}(u2)] \cdot f[\operatorname{ls}(u2)][\operatorname{rs}(u1)]\).
边界条件是递归完所有的二进制位后可以任意选取(不存在约束条件)。
注意当 \(u1 = u2\) 的时候需要特判。
时间复杂度是 \(O(n \log |V|)\).
代码
#include <cstdio>
using namespace std;
typedef long long ll;
const int maxn = 1.5e5 + 5;
const int tr_sz = maxn * 40;
const int mod = 998244353;
int n, x;
int cnt_nd = 1;
int a[maxn], pw[maxn];
int son[tr_sz][2], sz[tr_sz];
void insert(int x)
{
int u = 1;
for (int i = 30; i >= 0; i--)
{
int c = (x >> i) & 1;
if (!son[u][c]) son[u][c] = ++cnt_nd;
sz[u]++, u = son[u][c];
}
sz[u]++;
}
int dfs(int u1, int u2, int dep)
{
if ((!u1) || (!u2)) return pw[sz[u1 | u2]];
if (u1 == u2)
{
if (dep < 0) return pw[sz[u1]];
int ls = son[u1][0], rs = son[u1][1];
if ((x >> dep) & 1) return dfs(ls, rs, dep - 1);
return ((ll)dfs(ls, ls, dep - 1) + dfs(rs, rs, dep - 1) - 1 + mod) % mod;
}
if (dep < 0) return pw[sz[u1] + sz[u2]];
int ls1 = son[u1][0], rs1 = son[u1][1], ls2 = son[u2][0], rs2 = son[u2][1];
if ((x >> dep) & 1) return (ll)dfs(ls1, rs2, dep - 1) * dfs(rs1, ls2, dep - 1) % mod;
int res = ((ll)dfs(ls1, ls2, dep - 1) + dfs(rs1, rs2, dep - 1) - 1 + mod) % mod;
res = (res + ((ll)pw[sz[ls1]] - 1 + mod) * ((ll)pw[sz[rs1]] - 1 + mod) % mod) % mod;
res = (res + ((ll)pw[sz[ls2]] - 1 + mod) * ((ll)pw[sz[rs2]] - 1 + mod) % mod) % mod;
return res;
}
int main()
{
scanf("%d%d", &n, &x);
pw[0] = 1;
for (int i = 1; i <= n; i++)
{
scanf("%d", &a[i]);
insert(a[i]);
pw[i] = (pw[i - 1] << 1) % mod;
}
int ans = (dfs(1, 1, 30) - 1 + mod) % mod;
printf("%d\n", ans);
return 0;
}