XJOI 3866 写什么名字好呢
题意
给你一个数组\(R\),包含\(N\)个元素,求有多少满足条件的序列A使得
输出答案对\(1e9+9\)取模
输入格式
第一行输入一个整数\(N (2≤N≤10)\)
第二行输入\(N\) 个整数 \(R[i] (1≤R[i]≤1e18)\)
输出格式
输出一个整数
样例输入&输出
样例1
2
3 5
15
样例2
3
3 3 3
16
样例3
2
1 128
194
样例4
4
26 74 25 30
8409
样例5
2
1000000000 1000000000
420352509
分析
这道题目一定有很大价值。我想了一天
首先,暴力dfs肯定不行,看到\(0≤A[i]≤R[i]\),联想到最近学的数位dp。
如果不会数位dp或不太熟练,建议先学一学或巩固一下,否则接下来可能会不能理解
由$$A[0]+A[1]+...+A[N−1]=A[0]or A[1]...or A[N−1]$$这个条件,我们容易推出它的充分条件是
设 \(A_i\) 的二进制第 \(j\) 位为\(A_{i,j}\) ,则对于某一位 \(j\) ,有\(\sum_{i=1}^{n} A_{i,j}=1\) 或 \(0\) ,即在第 \(j\) 位上,只存在一个 \(A_i\) 或不存在任何一个 \(A_i\) ,使得 \(A_{i,j}\) 为 \(1\) 。
由此我们dp,设 \(dp[j][i][k](i≠0)\) 表示在第 \(j\) 位上,\(A_{i,j}\) 都等于 \(1\),其余的\(A_{i,j'}(j' \in [1,n],j' ≠0)\) 都为 \(0\) , \(k\) 为状态压缩,即用二进制压掉; \(dp[j][0][k]\) 表示在第 \(j\) 位上,所有的 \(A_{i,j}\) 都为 \(0\) 。
上面为什么要状压?因为在数位dp中,对于当前dfs到的数要用一个limit记录它是否受最高位限制,如果对于每一个 \(A_i\) 都开一个limit,那代码会变得十分冗杂。用二进制不仅能简化代码,还能优化掉空间复杂度\(\frac{1}{8}\)的常数。由于 \(n\le 10\) ,所以初始化limit变量为1023即可。别忘了把dp数组初始化为 \(-1\) 。
Code
#include<cstdio>
#define maxn 12
#define maxw 70
#define mod 1000000009 //别写错
#define get(x,pos) bool((x)&(1ll<<pos)) //返回x的二进制位的第pos+1个数字
#define set0(x,pos) ((x)&(1023-(1<<pos))) //把x的第pos+1个二进制位设为0
using namespace std;
int lg(long long x){ //log2(懒得调用cmath)
int ret=-1;
while(x){
ret++;
x>>=1;
}
return ret;
}
long long a[maxn],dp[maxw][maxn][1030],ma; //记得开long long
int n;
long long dfs(int pos,int limit){ //数位dp的记搜写法(循环写法我不太会)
if(pos<0){
return 1;
}
long long ret=0,ans=0;
int up,tmp=limit;
for(int i=1;i<=n;i++){
if(dp[pos][i][limit]!=-1){
ret=(ret+dp[pos][i][limit])%mod;
}
else{
up=get(limit,i-1)?get(a[i],pos):1;
ans=0;
if(up){
for(int j=1;j<=n;j++){
if(j!=i){
if(!(get(limit,j-1)&&!get(a[j],pos)))limit=set0(limit,j-1);
}
}
ans=dfs(pos-1,limit);
limit=tmp;
}
dp[pos][i][limit]=ans;
ret=(ret+ans)%mod;
}
}
if(dp[pos][0][limit]!=-1){
ret=(ret+dp[pos][0][limit])%mod;
}
else{
for(int j=1;j<=n;j++){
if(!(get(limit,j-1)&&!get(a[j],pos)))limit=set0(limit,j-1);
}
ans=dfs(pos-1,limit);
limit=tmp;
dp[pos][0][limit]=ans;
ret=(ret+ans)%mod;
}
return ret;
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
if(a[i]>ma)ma=a[i];
}
ma=lg(ma);
for(int i=0;i<=ma;i++){
for(int j=0;j<=n;j++){
for(int k=0;k<1024;k++){
dp[i][j][k]=-1;
}
}
}
printf("%lld\n",dfs(ma,1023));
return 0;
}