Loading

[CCPC2022 广东] XOR Sum

数位 dp

看到这样求和价值的计算,考虑可不可以交换求和符号或者改变计算方式。

这题中的位运算使我们考虑按位计算贡献,价值可以写成:

\[f(A)=\sum_{i=0}2^i\times c_i\times (k-c_i) \]

其中 \(c_i\) 表示第 \(i\) 位上为 \(1\)\(a_i\) 数量。

题目第二个要求即 \(f(A)=n\)。考虑从高位到低位计算贡献,类似数位 dp 计算方案数。于是序列中的元素就分为两种:卡了上界和没卡上界的。并且计算到当前位时需要知道低位留下来的余数,使该位最终与 \(n\) 上这一位相同。

\(dp(i,j,k)\) 表示考虑完从高到低前 \(i\) 位,此时低位留下的余数为 \(j\),卡了上界的数的数量为 \(k\) 的方案数。

转移看 \(m\) 上第 \(i\) 位上是 \(0\) 还是 \(1\)

如果是 \(0\),那么卡上界的数只能继续卡上界,枚举没卡上界的数中 \(1\) 的个数。

如果是 \(1\),分别枚举卡上界和不卡上界的 \(1\) 的个数。

\(1\) 的位置不固定,所以需要预处理组合数。

记忆化搜索即可。

分析当前位上余数最多是多少,如果余数 \(cnt\) 满足 \((cnt-81)\times 2\ge cnt\),那么低位上不存在一种方案使得出现这样的余数。

复杂度 \(O(50\times162\times18\times18\times 18)\),远小于此数。

#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define mk std::make_pair
#define fi first
#define se second
#define pb push_back

using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 50, M = 170, K = 19, mod = 1e9 + 7;
i64 n, m, k;
i64 f[N][M][K], a[N], c[K][K];
i64 dfs(int dep, int left, int cnt) {
	if(left >= 162) return 0;
	if(dep == -1) return left == 0;
	if(f[dep][left][cnt] != -1) return f[dep][left][cnt];
	int dig = (!dep ? 0 : ((n >> (dep - 1)) & 1));
	i64 ans = 0;
	if(!a[dep]) {
		for(int i = 0; i <= k - cnt; i++) {
			int cur = left - 1LL * i * (k - i);
			if(cur < 0) continue;
			ans = (ans + c[k - cnt][i] * dfs(dep - 1, (cur << 1) | dig, cnt) % mod) % mod;
		}
	} else {
		for(int i = 0; i <= cnt; i++) {
			for(int j = 0; j <= k - cnt; j++) {
				int cur = left - 1LL * (i + j) * (k - j - i);
				if(cur < 0) continue;
				ans = (ans + c[cnt][i] * c[k - cnt][j] % mod * dfs(dep - 1, (cur << 1) | dig, i) % mod) % mod;
			}
		}
	}
	f[dep][left][cnt] = ans;
	return ans;
}
int solve() {
	if(k == 1) return !n;		
	if(!m) return !n;
	memset(f, -1, sizeof(f));
	i64 l = 0, left = 0;
	while(m) {
		a[l++] = m % 2;
		m >>= 1;
	}
	for(int i = l; i <= 50; i++) {
		if((n >> i) & 1) {
			left += (1LL << (i - l));
		}
	}
	if(left >= 162) return 0;
	return dfs(l, left, k);
}
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    
	std::cin >> n >> m >> k;

	c[0][0] = 1;
	for(int i = 1; i <= k; i++) {
		c[i][0] = 1;
		for(int j = 1; j <= i; j++) {
			c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
		}
	}
	std::cout << solve() << "\n";

	return 0;
}
posted @ 2024-07-21 14:57  Fire_Raku  阅读(78)  评论(0编辑  收藏  举报