「postOI」Lost Array
题意
有一个序列 \(A=\{a_1, a_2, ..., a_n\}\),按如下方式构造一个 \((n + 1) \times (n + 1)\) 的矩阵 \(B\):
- \(B_{i0}=0\)(\(0\le i\le n\));
- \(B_{0i} = a_i\)(\(1 \le i \le n\));
- \(B_{ij} = B_{(i - 1)j} \text{ xor } B_{i(j - 1)}\)(\(1 \le i, j \le n\))。
现在给出 \(B_{1n}, B_{2n}, ..., B_{nn}\)(也就是最后一列,但是没有 \(B_{0n}\)),求出 \(A\)。
\(n \le 5 \times 10^5\)
解析
题目给出的是 \(B\) 的递推式,我们希望得到计算式,换句话说,我们希望直接得到 \(B_{in}\) 只与 \(A\) 有关的表达式。
倒过来想,考虑 \(a_i\) 对 \(B_{jn}\) 的贡献。由于是异或,\(a_i\) 只可能贡献 \(a_i\) 或 \(0\)。
那具体贡献多少?我们可以把问题具象化,\(B\) 的递推式相当于是“向左走一步或向上走一步”。那么只需要判断从 \(B_{jn}\) 走到 \(B_{0i}\) 的方案数是奇数还是偶数。这里有一个小细节——“走到 \(B_{0i}\) 就结束了,不能继续走到 \(B_{0(i-1)}\)”,这个细节相当于说最后一步一定是向上的。
这样我们就可以通过组合数算出 \(B_{jn}\) 到 \(B_{0i}\) 的方案数:
这样并不好看,我们设 \(a'_i = a_{n - i}\),\(b_i = B_{i + 1, n}\)。那么 \(a_i'\) 对 \(b_j\) 的贡献只需要看 \(\binom{i + j}{j}\) 的奇偶性。关于组合数的奇偶性,结论如下:
\(\binom ab\) 为奇数当且仅当 \(a\text{ and }b = b\)。
也就是说 \(a_i'\) 对 \(b_j\) 有贡献当且仅当 \((i + j)\text{ and }j=j\) 等价于 \(i\text{ and }j = 0\)。
似乎有一个做法,如果把 \(j\) 取个补集,那条件不就是 \(i\in j\)(\(i\text{ and }j = i\)),那么
这 \(b'\) 不就是 \(a'\) 做了或卷积 FWT 的结果吗?然而,由于 \(n\) 未必是 \(2\) 的整次方,\(b'_0, b'_1, ..., b_{n - 1}'\) 中有几项我们不知道。这个方法就这么废了……
我们再考虑一下 \(i\text{ and }j\) 能怎么处理——容斥?我们钦定 \(i\) 对应的二进制位全为 \(1\),即 \(i\text{ and }j = i\),记为 \(c_i\):
则由容斥可得下式(容斥的正负系数在异或中没有意义)
唔,看起来好像没什么区别,还更麻烦了?先分析一下,由于 \(c_i\) 是计算 \(i\) 的超集的异或和,那么当 \(i \ge n\) 时,\(c_i = 0\)。于是我们只需要计算 \(c_0, c_1, ..., c_{n - 1}\),那么我们可以通过 \(b\) 计算出这些结果吗?
当然是可以的——因为这是或卷积的 fwt,计算 \(b_i\) 只需要用到 \(j \le i\) 的 \(c_j\),那么就可以通过 \(b_{0~(n-1)}\) 做一遍或卷积 fmt 求出 \(c_{0~(n-1)}\)。
最后再用完整的 \(c\) 做一遍与卷积 fmt 求得 \(a'\)。
源代码
#include <cstdio>
#include <cstring>
#include <algorithm>
const int MAXN = (int)5e5 + 10;
int len, lg2_len;
int arr[MAXN];
void fwtOr()
{
for (int i = 0; i <= lg2_len; ++i)
{
for (int j = 0; j < len; ++j)
{
if (j & (1 << i))
{
arr[j] ^= arr[j ^ (1 << i)];
}
}
}
}
void fwtAnd()
{
for (int i = 0; i <= lg2_len; ++i)
{
for (int j = 0; j < len; ++j)
{
if (j & (1 << i))
{
arr[j ^ (1 << i)] ^= arr[j];
}
}
}
}
int main()
{
scanf("%d", &len);
for (int i = 0; i < len; ++i)
{
scanf("%d", &arr[i]);
}
while ((1 << lg2_len) < len)
{
++lg2_len;
}
fwtOr();
fwtAnd();
for (int i = 1; i < len; ++i)
{
printf("%d ", arr[len - i]);
}
printf("%d\n", arr[0]);
return 0;
}