[题解] XOR Problem
题目大意
对于一个整数序列 \(a_{0...5}\),我们定义它的价值为:
\(f(a)=max(|a_0-a_3|,|a_1-a_4|,|a_2-a_5|)\oplus a_0 \oplus a_1 \oplus a_2 \oplus a_3 \oplus a_4 \oplus a_5\)
其中 \(\oplus\) 是异或操作。
现在给定序列 \(b_{0...5}\),你需要求对于所有满足 \(\forall i,0\leq a_i\leq b_i\) 的序列 \(a\) 的 \(f(a)\) 之和。
由于答案可能很大,你只需要输出答案对 \(2^{64}\) 取模后的值。
对于所有数据,有 \(0\leq b_i\leq 30000\)
时间限制:3 s
空间限制:512 MB
解题思路
是个不太难的题,然而我不会最后一档分。
考场做法是 \(f_{i,j}\) 表示 \(\max\) 不大于 \(i\),\(6\) 个数的异或值为 \(j\) 的方案数。
转移前缀和优化一下,再用奇怪的技巧加个 fwt 就可以有 62pts 了。
考虑如何把都是 \(O(b^2)\) 的时空复杂度降下去,一个套路的做法是把异或值每一位拆开计算。
\(f_{i,j,0/1}\) 表示 \(\max\) 不大于 \(i\),第 \(j\) 位为 \(0/1\) 的方案数,对于每一对分开求,再用 \(f\) dp 就不难了。
那 \(f\) 怎么求 ? 想到差一定是可以卷积的 (将数组翻转),把每一组,每一位,每一种 \(0/1\) 情况分开卷。
接下来的 dp 就和之前的暴力 dp 差不多了。
#include <bits/stdc++.h>
using namespace std;
const int B(30005), BIT(16);
int b[6];
int f[3][BIT][B][2];
inline void read(int &x ){
x = 0; int f = 1, c = getchar();
while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
while(isdigit(c)) x = x * 10 + c - 48, c = getchar();
x *= f;
}
namespace prep{
const int mod(998244353), G(3), Gi(332748118);
int n;
int a[1 << 16], b[1 << 16], c[1 << 16], rev[1 << 16];
inline void MOD(int &x){ x = x + ((x >> 31) & mod); }
inline int qpow(int x, int a){
int sum = 1; while(a){
if(a & 1) sum = 1LL * sum * x % mod;
x = 1LL * x * x % mod, a >>= 1;
} return sum;
}
void init(int len){
n = 1; while(n < len) n <<= 1;
for(int i(0); i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | (i & 1) * (n >> 1);
}
void NTT(int *a, int n, int op){
int g = op == 1 ? G : Gi, inv = qpow(n, mod - 2);
for(int i(0); i < n; ++i) if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int len(2), hf(1); len <= n; len <<= 1, hf <<= 1){
int wn = qpow(g, (mod - 1) / len);
for(int bs(0); bs < n; bs += len){
int w = 1;
for(int p(0); p < hf; ++p, w = 1LL * w * wn % mod){
int x = a[bs + p], y = 1LL * w * a[bs + hf + p] % mod;
MOD(a[bs + p] = x + y - mod);
MOD(a[bs + hf + p] = x - y);
}
}
}
if(op == -1) for(int i(0); i < n; ++i) a[i] = 1LL * a[i] * inv % mod;
}
void work(int t, int bit, int la, int lb, int x, int y){
memset(a, 0, sizeof a);
memset(b, 0, sizeof b);
memset(c, 0, sizeof c);
/* 将这一位符合枚举要求的取出来卷积 */
for(int i(0); i <= la; ++i) a[i] = ((i >> bit) & 1) == x;
for(int i(0); i <= lb; ++i) b[i] = ((i >> bit) & 1) == y;
/* 因为差是定值所以翻转一下 */
reverse(b, b + lb + 1);
init(la + lb + 2);
NTT(a, n, 1), NTT(b, n, 1);
for(int i(0); i < n; ++i) c[i] = 1LL * a[i] * b[i] % mod;
NTT(c, n, -1);
for(int i(0); i <= la; ++i) MOD(f[t][bit][i][x ^ y] += c[i + lb] - mod);
for(int i(0); i <= lb; ++i) MOD(f[t][bit][lb - i][x ^ y] += c[i] - mod);
}
void calc(int t, int bit, int la, int lb){
for(int x(0); x <= 1; ++x)
for(int y(0); y <= 1; ++y)
work(t, bit, la, lb, x, y);
/* 两数相等的时候算重了一次 */
f[t][bit][0][0] = 1LL * f[t][bit][0][0] * qpow(2, mod - 2) % mod;
f[t][bit][0][1] = 1LL * f[t][bit][0][1] * qpow(2, mod - 2) % mod;
}
}
namespace DP{
#define ULL unsigned long long
ULL ans, pre[4][B][2], dp[4][B][2];
void solve(int bit){
memset(dp, 0, sizeof(dp));
memset(pre, 0, sizeof(dp));
for(int t(1); t <= 3; ++t)
for(int i(0); i < B; ++i)
for(int j(0); j <= 1; ++j)
pre[t][i][j] = f[t - 1][bit][i][j] + (i ? pre[t][i - 1][j] : 0);
for(int j(0); j < B; ++j) dp[0][j][0] = 1;
for(int i(1); i <= 3; ++i)
for(int j(0); j < B; ++j){
for(int x(0); x <= 1; ++x)
for(int y(0); y <= 1; ++y)
dp[i][j][x ^ y] += dp[i - 1][j][x] * pre[i][j][y];
}
for(int i(B - 1); i; --i)
for(int j(0); j <= 1; ++j)
dp[3][i][j] -= dp[3][i - 1][j];
for(int i(0); i < B; ++i) for(int j(0); j <= 1; ++j)
if(((i >> bit) & 1) ^ j) ans += dp[3][i][j] * (1ULL << bit);
}
}
int main(){
// freopen("C.in", "r", stdin);
// freopen("C.out", "w", stdout);
for(int i(0); i < 6; ++i) read(b[i]);
for(int i(0); i < 16; ++i)
for(int j(0); j < 3; ++j)
prep :: calc(j, i, b[j], b[j + 3]);
for(int i(0); i < 16; ++i) DP :: solve(i);
cout << DP :: ans << endl;
return 0;
}