[CCPC2022 广东] XOR Sum
数位 dp
看到这样求和价值的计算,考虑可不可以交换求和符号或者改变计算方式。
这题中的位运算使我们考虑按位计算贡献,价值可以写成:
\[f(A)=\sum_{i=0}2^i\times c_i\times (k-c_i)
\]
其中 \(c_i\) 表示第 \(i\) 位上为 \(1\) 的 \(a_i\) 数量。
题目第二个要求即 \(f(A)=n\)。考虑从高位到低位计算贡献,类似数位 dp 计算方案数。于是序列中的元素就分为两种:卡了上界和没卡上界的。并且计算到当前位时需要知道低位留下来的余数,使该位最终与 \(n\) 上这一位相同。
设 \(dp(i,j,k)\) 表示考虑完从高到低前 \(i\) 位,此时低位留下的余数为 \(j\),卡了上界的数的数量为 \(k\) 的方案数。
转移看 \(m\) 上第 \(i\) 位上是 \(0\) 还是 \(1\):
如果是 \(0\),那么卡上界的数只能继续卡上界,枚举没卡上界的数中 \(1\) 的个数。
如果是 \(1\),分别枚举卡上界和不卡上界的 \(1\) 的个数。
\(1\) 的位置不固定,所以需要预处理组合数。
记忆化搜索即可。
分析当前位上余数最多是多少,如果余数 \(cnt\) 满足 \((cnt-81)\times 2\ge cnt\),那么低位上不存在一种方案使得出现这样的余数。
复杂度 \(O(50\times162\times18\times18\times 18)\),远小于此数。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define mk std::make_pair
#define fi first
#define se second
#define pb push_back
using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 50, M = 170, K = 19, mod = 1e9 + 7;
i64 n, m, k;
i64 f[N][M][K], a[N], c[K][K];
i64 dfs(int dep, int left, int cnt) {
if(left >= 162) return 0;
if(dep == -1) return left == 0;
if(f[dep][left][cnt] != -1) return f[dep][left][cnt];
int dig = (!dep ? 0 : ((n >> (dep - 1)) & 1));
i64 ans = 0;
if(!a[dep]) {
for(int i = 0; i <= k - cnt; i++) {
int cur = left - 1LL * i * (k - i);
if(cur < 0) continue;
ans = (ans + c[k - cnt][i] * dfs(dep - 1, (cur << 1) | dig, cnt) % mod) % mod;
}
} else {
for(int i = 0; i <= cnt; i++) {
for(int j = 0; j <= k - cnt; j++) {
int cur = left - 1LL * (i + j) * (k - j - i);
if(cur < 0) continue;
ans = (ans + c[cnt][i] * c[k - cnt][j] % mod * dfs(dep - 1, (cur << 1) | dig, i) % mod) % mod;
}
}
}
f[dep][left][cnt] = ans;
return ans;
}
int solve() {
if(k == 1) return !n;
if(!m) return !n;
memset(f, -1, sizeof(f));
i64 l = 0, left = 0;
while(m) {
a[l++] = m % 2;
m >>= 1;
}
for(int i = l; i <= 50; i++) {
if((n >> i) & 1) {
left += (1LL << (i - l));
}
}
if(left >= 162) return 0;
return dfs(l, left, k);
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cin >> n >> m >> k;
c[0][0] = 1;
for(int i = 1; i <= k; i++) {
c[i][0] = 1;
for(int j = 1; j <= i; j++) {
c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
}
}
std::cout << solve() << "\n";
return 0;
}