Codeforces 1431J. Zero-XOR Array (3400)
题目描述
给定一个长度为 \(n\) 的序列 \(a\),需要计算满足下列条件的序列 \(b\) 的个数,答案对 \(998244353\) 取模。
- 序列 \(b\) 的长度为 \(2n-1\);
- \(b_{2i-1}=a_i(i\in [1,n])\);
- 序列 \(b\) 单调不减,即 \(b_1\le b_2\le \cdots\le b_{2n-1}\);
- \(b_1\oplus b_2\oplus\cdots \oplus\ b_{2n-1}=0\)。
\(1\le n\le 18,1\le m\le 60,b_i\in [0,2^m)\)。
原题只能用 \(\text{Kotlin}\) 语言提交,可以在 这里 提交其他语言。
首先将非 \(a_i\) 的元素都提取出来构成一个新的长度为 \(n-1\) 的序列 \(b\),这样题意转化成求长度为 \(n-1\) 的序列 \(b\),使得 \(a_i\le b_i\le a_{i+1}\),且满足 \(b_1\oplus b_2\oplus\cdots\oplus b_{n-1}=a_1\oplus a_2\oplus\cdots\oplus a_{n}=C\)。
由于位与位之间独立,可以考虑每一位的贡献。
观察对于任意 \(b_i\) 的某一位 \(t\),有 \(3\) 种情况\((\)其中第 \(m-1\) 位是最高位\()\)。
- \(b_i\) 的第 \(t\) 位到第 \(m-1\) 位组成的前缀与 \(a_i\) 相同。
- \(b_i\) 的第 \(t\) 位到第 \(m-1\) 位组成的前缀与 \(a_i\) 不同,与 \(a_{i+1}\) 相同。
- \(b_i\) 的第 \(t\) 位到第 \(m-1\) 位组成的前缀既与 \(a_i\) 不同,也与 \(a_{i+1}\) 不同。
也就是考虑 \(b_i\) 的每一段前缀是否达到上界、下界、两界都没达到\((\)在中间\()\)。
对于达到上下界的两种情况显然可以比较容易记录,关键是两界都没达到的情况。
进一步观察,发现若存在 \(b_i\) 的某一位 \(t\) 到第 \(m-1\) 位的前缀两界都没达到,且固定了所有 \(b_i\) 的 \(j\) 到 \(m-1\) 位,记 \(\min_{i}\) 为 \(b_i\) 可以达到的最小值,\(\max_{i}\) 为 \(b_i\) 可以达到的最大值,那么方案数为 \(\prod\limits_{k\neq i}(\max_{k}-\min_{k}+1)\)。
因为若确定了其他所有 \(b_i\) 在 \(j\) 位以后的取值,那么由于异或值固定,所以当前 \(b_i\) 的 \(j\) 位后的取值也必定固定。
于是可以枚举第一次出现两界都没达到的位数 \(j\) 以及在第 \(j\) 位以前的每个 \(b_i\) 的前缀达到的上\(/\)下界,设 \(dp_{i,0/1/2/3}\) 表示确定了前 \(i\) 个数的 \(j\sim m-1\) 位,第 \(j\) 位的异或和为 \(0/1\) 以及是\(/\)否出现了两界都没达到的情况。然后分情况转移,具体状态方程见代码实现。
时间复杂度 \(O(nm2^n)\)。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 19, mod = 998244353;
int n, m, ans, pos[N]; ll sum, l[N], r[N], dp[N][4];
inline bool check(int s, int k) {
ll ss = 0;
for (int i = 1; i < n; ++ i) {
if (s >> (i - 1) & 1) {
if (pos[i] < k) return 0;
ss ^= r[i];
} else ss ^= l[i];
} return (ss >> k) == (sum >> k);
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++ i) scanf("%lld", l + i), r[i - 1] = l[i], sum ^= l[i];
for (int i = 1; i < n; ++ i) {
pos[i] = -1;
for (int j = m - 1; ~j; -- j) if ((l[i] >> j & 1) ^ (r[i] >> j & 1)) {
pos[i] = j; break;
}
}
for (int i = 0; i < 1 << n - 1; ++ i) if (check(i, 0)) ++ ans;
for (int i = 0; i < m; ++ i) for (int j = 0; j < 1LL << n - 1; ++ j) if (check(j, i + 1)) {
memset(dp, 0, sizeof dp), dp[0][0] = 1;
for (int p = 1; p < n; ++ p) {
if (pos[p] < i) {
int d = l[p] >> i & 1, num = (r[p] - l[p] + 1) % mod;
for (int t = 0; t < 4; ++ t)
(dp[p][t] += dp[p - 1][t ^ d] * num) %= mod;
} else if (pos[p] == i) {
ll lim = l[p] | ((1LL << i) - 1);
int n0 = (lim - l[p] + 1) % mod, n1 = (r[p] - lim) % mod;
for (int t = 0; t < 4; ++ t) {
(dp[p][t] += dp[p - 1][t] * n0) %= mod;
(dp[p][t] += dp[p - 1][t ^ 1] * n1) %= mod;
}
} else if (~j >> (p - 1) & 1) {
if (l[p] >> i & 1) {
int num = ((1LL << i) - l[p] % (1LL << i)) % mod;
for (int t = 0; t < 4; ++ t)
(dp[p][t] += dp[p - 1][t ^ 1] * num) %= mod;
} else {
int n0 = ((1LL << i) - l[p] % (1LL << i)) % mod, n1 = (1LL << i) % mod;
for (int t = 0; t < 4; ++ t)
(dp[p][t] += dp[p - 1][t] * n0) %= mod;
for (int t = 2; t < 4; ++ t)
(dp[p][t] += dp[p - 1][t ^ 1] * n1 + dp[p - 1][(t ^ 1) & 1]) %= mod;
}
} else {
if (r[p] >> i & 1) {
ll lim = r[p] >> i << i;
int n0 = (1LL << i) % mod, n1 = (r[p] - lim + 1) % mod;
for (int t = 0; t < 4; ++ t)
(dp[p][t] += dp[p - 1][t ^ 1] * n1) %= mod;
for (int t = 2; t < 4; ++ t)
(dp[p][t] += dp[p - 1][t] * n0 + dp[p - 1][t & 1]) %= mod;
} else {
int num = (r[p] - (r[p] >> i << i) + 1) % mod;
for (int t = 0; t < 4; ++ t)
(dp[p][t] += dp[p - 1][t] * num) %= mod;
}
}
}
(ans += dp[n - 1][(sum >> i & 1) | 2]) %= mod;
}
return printf("%d\n", ans), 0;
}