[2022天梯赛] 教科书般的亵渎 【记忆化搜索】【剪枝】
题目描述:
$n$张牌每个牌有权值$a_i$,要求选择$k$次,每次让牌的权值减一,使得牌的权值形成从$1$开始的连续整数(不含$0$).
$n,k,ai \leq 50$
分析:
先考虑朴素dp,先将$a_i$排序,$dp[i][S][j]$表示前$i$个数,把$S$这些位填上了,还剩$j$次行动机会的方案数。有
$$dp[i][S][j] -> dp[i+1][S|(1<<k)][j-(a_i-k+1)]$$
考虑剪枝,如果$S$加上后面还没用的$a_i$在操作数最小且形成连续整数的情况下仍然比能用的次数$j$要大,则剪去。用unorderedmap存状态。
代码:
1 #include<bits/stdc++.h> 2 using namespace std; 3 4 const int maxn = 60; 5 const int mod = 998244353; 6 7 int n,k; 8 int a[maxn]; 9 10 int C[maxn][maxn]; 11 int ans; 12 unordered_map <unsigned long long,int> mp[2][52],imm; 13 struct node{ 14 int kk; 15 unsigned long long now; 16 }; 17 queue<node> q[2]; 18 19 bool pd(node fa,int i){ 20 int dj = 0; 21 for(;i<=n;i++){ 22 unsigned long long pj = (fa.now+1)&(~fa.now); 23 if((1ull<<a[i]-1)<pj) continue; 24 int pp = imm[pj]; 25 dj += a[i]-pp; 26 if(dj > fa.kk) return 0; 27 fa.now += pj; 28 } 29 return 1; 30 } 31 32 int fast_pow(int now,int pw){ 33 int ans = 1,dt = 1; 34 while(dt <= pw){ 35 if(dt & pw) ans = 1ll*ans*now%mod; 36 dt<<=1; 37 now = 1ll*now*now%mod; 38 } 39 return ans; 40 } 41 42 int cnt = 0; 43 void dfs(int now,unsigned long long st,int num,int way){ 44 if(now > n){ 45 if(num != 0) return; 46 int flag = imm.count(st+1); 47 if(!flag) return; 48 ans += way; 49 if(ans >= mod) ans -= mod; 50 }else{ 51 for(int i=a[now];i>=1;i--){ 52 if(a[now]-i > num) break; 53 unsigned long long bt = st|(1ull<<i-1); 54 int sz = num-(a[now]-i); 55 int tms = 1ll*way*C[num][a[now]-i]%mod; 56 int zz = (now+1)&1; 57 if(now == n){dfs(now+1,bt,sz,tms);continue;} 58 if(!pd((node){sz,bt},now+1)) continue; 59 if(mp[zz][sz].count(bt)){ 60 mp[zz][sz][bt]=(mp[zz][sz][bt]+tms)%mod; 61 }else { 62 q[zz].push((node){sz,bt}); 63 mp[zz][sz][bt] = tms; 64 } 65 } 66 } 67 } 68 69 70 int main(){ 71 ios::sync_with_stdio(false); 72 cin >> n >> k; 73 for(int i=0;i<=50;i++){ 74 unsigned long long u = 1ull<<i; 75 imm[u] = i+1; 76 } 77 for(int i=0;i<=k;i++){ 78 C[i][0] = C[i][i] = 1; 79 for(int j=1;j<i;j++) C[i][j] = (C[i-1][j-1]+C[i-1][j])%mod; 80 } 81 for(int i=1;i<=n;i++) cin >> a[i]; 82 sort(a+1,a+n+1); 83 mp[1][k][0] = 1; q[1].push((node){k,0}); 84 for(int i=1;i<=n;i++){ 85 for(int j=0;j<=k;j++) mp[(i&1)^1][j].clear(); 86 //cnt=0; 87 while(!q[i&1].empty()){ 88 //cnt++; 89 node kd = q[i&1].front(); q[i&1].pop(); 90 dfs(i,kd.now,kd.kk,mp[i&1][kd.kk][kd.now]); 91 } 92 //cout<<cnt<<endl; 93 } 94 cout<<1ll*ans*fast_pow(fast_pow(n,k),mod-2)%mod<<endl; 95 return 0; 96 }