[2022CCPC广州] XOR Sum

这个肯定要分二进制位来做了,所以考虑数位dp,先看要放哪些东西进dp状态:现在处理到第几位了,有多少个数现在是顶着上界的,还差多少值和才能到 \(n\)。现在看怎么转移。\(k\le 18\) ,可以直接枚举这一位上有多少个 \(1\) ,当第 \(t\) 位有 \(x\)\(1\) 时,对和的贡献是 \(2^{t} \times x \times(k-x)\) ,这是显然的。对于 \(m\) 的第 \(t\) 位是否为 \(1\),要进行分类讨论,如果是,就可以从顶着上界的数里面选 \(1\) ,然后被选为\(1\) 的继续顶着上界。如果\(m\) 的第 \(t\) 位不为为 \(1\),则只能从没有顶着上界的数里面选 \(1\)。两种情况的答案都要乘组合数。但是直接这样写是要 TLE 的,加入剪枝:如果后面所有位都按照最优的情况来选(即为 \(1\)\(0\) 各一半)和还是小于 \(n\) 就直接返回 \(0\) ,就是这两行:

	int s1=k/2,s0=k-s0;
	if(((1ll<<t+1)-1)*s0*s1<tot)return 0;

时间复杂度证明可以看这里

完整代码:

#include <bits/stdc++.h>
#define int long long
#define ll long long
#define ull unsigned long long
#pragma GCC optimeze(3)
#pragma GCC optimeze(2)
#define PII pair<int, int>
#define pb push_back
#define fi first
#define se second
#define lowbit(x) (x & (-x))
#define inv(x) (qpow(x,mod-2))
#define lwz lower_bound
#define blong(i) ((i+K-1)/K)
using namespace std;
const int N=2e3+5;
const int M=3e2+5;
const int mod=1e9+7;
double eps=1e-6;
inline int read(){
	char ch=getchar();bool f=0;int x=0;
	for(;!isdigit(ch);ch=getchar())if(ch=='-')f=1;
	for(;isdigit(ch);ch=getchar())x=(x<<1)+(x<<3)+(ch^48);
	if(f==1)x=-x;return x;
}
ll qpow(ll a,ll b){
	ll ans=1;
	while(b){
		if(b&1)ans*=a,ans%=mod;
		a*=a,a%=mod,b>>=1;
	}
	return ans;
}
int gcd(int a,int b){return b==0? a:gcd(b,a%b);}
void add(int&a,int b){a+=b;if(a>=mod)a-=mod;}
void minus(int&a,int b){a-=b;if(a<0)a+=mod;}
int n,k,m,a[40],c[40][40];
unordered_map<int,int>dp[40][40];
int dfs(int t,int lim,int tot){
	if(t<0)return tot==0;
	if(tot<0)return 0;
	int s1=k/2,s0=k-s0;
	if(((1ll<<t+1)-1)*s0*s1<tot)return 0;
	if(dp[t][lim].count(tot))return dp[t][lim][tot];
	int ret=0;
	if(a[t]==1){
		for(int i=0;i<=lim;i++){
			for(int j=0;j<=k-lim;j++){
				int val=(1ll<<t)*(i+j)*(k-i-j);
				add(ret,dfs(t-1,i,tot-val)*c[lim][i]%mod*c[k-lim][j]%mod);
			}
		}
	}
	else{
		for(int i=0;i<=k-lim;i++){
			int val=(1ll<<t)*i*(k-i);
			add(ret,dfs(t-1,lim,tot-val)*c[k-lim][i]%mod);
		}
	}
	dp[t][lim][tot]=ret;
	return ret;
}
signed main(){ 
	ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
	cin>>n>>m>>k;
	for(int i=0;i<=39;i++){
		a[i]=((m>>i)&1);
	}
	for(int i=0;i<=k;i++){
		c[i][0]=1;
		for(int j=1;j<=i;j++){
			c[i][j]=(c[i-1][j]+c[i-1][j-1])%mod;
		}
	}
	cout<<dfs(39,k,n);
	return 0;
} 	
posted @ 2025-02-20 21:55  Xdik  阅读(44)  评论(1)    收藏  举报