HDU - 5977 Garden of Eden (树形dp+容斥)
题意:一棵树上有n(n<=50000)个结点,结点有k(k<=10)种颜色,问树上总共有多少条包含所有颜色的路径。
我最初的想法是树形状压dp,设dp[u][S]为以结点u为根的包含颜色集合为S的路径条数,然后FWT(应该叫FMT?)搞一下就行了,复杂度$O(nk2^k)$。奈何内存太大,妥妥地MLE...
看到网上大部分的解法都是点分治,我不禁联想到之前学过的树上任意两点距离的求法(点分治+FFT),心想,这道题用点分治+FWT是不是也能过?于是比着葫芦画瓢写出了这样一段又臭又长的代码:
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 typedef double db; 5 const int N=5e4+10,inf=0x3f3f3f3f; 6 int n,k,a[N],hd[N],ne,vis[N],K,siz[N],tot,rt,mx; 7 ll dp[1<<10],ans; 8 struct E {int v,nxt;} e[N<<1]; 9 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;} 10 void FWT(ll* a,int n,int f) { 11 for(int k=1; k<n; k<<=1) 12 for(int i=0; i<n; i+=k<<1) 13 for(int j=i; j<i+k; ++j) 14 a[j+k]+=f==1?a[j]:-a[j]; 15 } 16 void getroot(int u,int fa) { 17 siz[u]=1; 18 int sz=0; 19 for(int i=hd[u]; ~i; i=e[i].nxt) { 20 int v=e[i].v; 21 if(vis[v]||v==fa)continue; 22 getroot(v,u); 23 siz[u]+=siz[v]; 24 sz=max(sz,siz[v]); 25 } 26 sz=max(sz,tot-siz[u]); 27 if(sz<mx)mx=sz,rt=u; 28 } 29 void dfs(int u,int fa,int S) { 30 dp[S]++; 31 for(int i=hd[u]; ~i; i=e[i].nxt) { 32 int v=e[i].v; 33 if(vis[v]||v==fa)continue; 34 dfs(v,u,S|a[v]); 35 } 36 } 37 void cal(int u,int ba,int f) { 38 for(int i=0; i<=K; ++i)dp[i]=0; 39 dfs(u,-1,a[u]|ba); 40 FWT(dp,K+1,1); 41 for(int i=0; i<=K; ++i)dp[i]*=dp[i]; 42 FWT(dp,K+1,-1); 43 ans+=dp[K]*f; 44 } 45 void solve(int u) { 46 mx=inf,getroot(u,-1),u=rt,cal(u,0,1),vis[u]=1; 47 for(int i=hd[u]; ~i; i=e[i].nxt) { 48 int v=e[i].v; 49 if(!vis[v])tot=siz[v],cal(v,a[u],-1),solve(v); 50 } 51 } 52 ll treepartion() { 53 ans=0,tot=n; 54 solve(1); 55 return ans; 56 } 57 int main() { 58 while(scanf("%d%d",&n,&k)==2) { 59 memset(hd,-1,sizeof hd),ne=0; 60 memset(vis,0,sizeof vis); 61 K=(1<<k)-1; 62 for(int i=1; i<=n; ++i)scanf("%d",&a[i]),a[i]=1<<(a[i]-1); 63 for(int i=1; i<n; ++i) { 64 int u,v; 65 scanf("%d%d",&u,&v); 66 addedge(u,v); 67 addedge(v,u); 68 } 69 printf("%lld\n",treepartion()); 70 } 71 return 0; 72 }
虽然成功地AC了,但是仔细一想:不对啊,这道题FWT的复杂度和子树的大小不是线性相关的啊!所以这样一来,总的复杂度成了$O(nk2^klogn)$,反而增大了。
也就是说,这道题用点分治的作用仅仅是减少了内存的开销,复杂度非但没有减少,反而还多了个logn!
当然除了点分治,这道题还有其他的优化方法,比如sclbgw7大佬利用树链剖分的思想将内存优化到了$O(2^klogn)$,时间复杂度仍为$O(nk2^k)$。
树剖做法:
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=5e4+10,inf=0x3f3f3f3f; 5 int n,k,a[N],hd[N],ne,S,fa[N],son[N],siz[N],tot; 6 ll dp[17][1<<10],b[1<<10],A[N],B[N],ans; 7 struct E {int v,nxt;} e[N<<1]; 8 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;} 9 void FWT(ll* a,int n,int f) { 10 for(int k=1; k<n; k<<=1) 11 for(int i=0; i<n; i+=k<<1) 12 for(int j=i; j<i+k; ++j) { 13 ll x=a[j],y=a[j+k]; 14 a[j+k]=f==1?y+x:y-x; 15 } 16 } 17 void mul(ll* a,ll* b,ll* c,int n) { 18 for(int i=0; i<n; ++i)A[i]=a[i],B[i]=b[i]; 19 FWT(A,n,1),FWT(B,n,1); 20 for(int i=0; i<n; ++i)c[i]=A[i]*B[i]; 21 FWT(c,n,-1); 22 } 23 int newnode() { 24 int u=tot++; 25 for(int i=0; i<(1<<k); ++i)dp[u][i]=0; 26 return u; 27 } 28 void dfs1(int u,int f) { 29 fa[u]=f,son[u]=0,siz[u]=1; 30 for(int i=hd[u]; ~i; i=e[i].nxt) { 31 int v=e[i].v; 32 if(v==fa[u])continue; 33 dfs1(v,u),siz[u]+=siz[v]; 34 if(siz[v]>siz[son[u]])son[u]=v; 35 } 36 } 37 void dfs2(int u,int w) { 38 if(son[u])dfs2(son[u],w); 39 for(int i=0; i<(1<<k); ++i)b[i]=0; 40 b[1<<a[u]]=1; 41 if((1<<a[u]==S))ans++; 42 for(int i=0; i<(1<<k); ++i)if((i|(1<<a[u]))==S)ans+=dp[w][i]*2; 43 for(int i=0; i<(1<<k); ++i)b[i|(1<<a[u])]+=dp[w][i]; 44 for(int i=0; i<(1<<k); ++i)dp[w][i]=b[i]; 45 for(int i=hd[u]; ~i; i=e[i].nxt) { 46 int v=e[i].v; 47 if(v==fa[u]||v==son[u])continue; 48 int wv=newnode(); 49 dfs2(v,wv),tot--; 50 mul(dp[w],dp[wv],b,1<<k); 51 ans+=b[S]*2; 52 for(int i=0; i<(1<<k); ++i)dp[w][i|(1<<a[u])]+=dp[wv][i]; 53 } 54 } 55 int main() { 56 while(scanf("%d%d",&n,&k)==2) { 57 memset(hd,-1,sizeof hd),ne=0; 58 S=(1<<k)-1; 59 for(int i=1; i<=n; ++i)scanf("%d",&a[i]),a[i]--; 60 for(int i=1; i<n; ++i) { 61 int u,v; 62 scanf("%d%d",&u,&v); 63 addedge(u,v); 64 addedge(v,u); 65 } 66 tot=ans=0; 67 dfs1(1,-1),dfs2(1,newnode()); 68 printf("%lld\n",ans); 69 } 70 return 0; 71 }
还有Menhera大佬利用基的FMT性质将时间复杂度优化到$O(n2^k)$的做法,看样子有点像是容斥。我个人更倾向于这一种,于是在这个思想的基础上进一步地分析:
题目要求的是包含所有颜色的路径条数。如果包含某个元素集合的路径不太好求,那么不包含某个元素集合的呢?只要把属于这个集合的结点都染成白色,不属于的都染成黑色,则问题就转化成了求一棵树上包含的所有点都是黑色的路径条数,直接dp求一下就行了。于是我们可以利用容斥原理,用所有的路径数减去不包含1个元素集合的路径数,再加上不包含2个元素集合的路径数,再减去不包含3个...就得到了答案。
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=5e4+10,inf=0x3f3f3f3f; 5 int n,k,a[N],siz[N],hd[N],ne,ppc[1<<10]; 6 ll ans; 7 struct E {int v,nxt;} e[N<<1]; 8 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;} 9 void dfs(int u,int fa,int f) { 10 for(int i=hd[u]; ~i; i=e[i].nxt)if(e[i].v!=fa)dfs(e[i].v,u,f); 11 if(!siz[u])return; 12 ans+=f; 13 for(int i=hd[u]; ~i; i=e[i].nxt)if(e[i].v!=fa) { 14 int v=e[i].v; 15 ans+=(ll)siz[v]*siz[u]*2*f; 16 siz[u]+=siz[v]; 17 } 18 } 19 ll solve() { 20 ans=0; 21 for(int S=(1<<k)-1; S; --S) { 22 int f=(k-ppc[S])&1?-1:1; 23 for(int i=1; i<=n; ++i)siz[i]=S>>a[i]&1; 24 dfs(1,-1,f); 25 } 26 return ans; 27 } 28 int main() { 29 ppc[0]=0; 30 for(int i=1; i<(1<<10); ++i)ppc[i]=ppc[i>>1]+(i&1); 31 while(scanf("%d%d",&n,&k)==2) { 32 memset(hd,-1,sizeof hd),ne=0; 33 for(int i=1; i<=n; ++i)scanf("%d",&a[i]),a[i]--; 34 for(int i=1; i<n; ++i) { 35 int u,v; 36 scanf("%d%d",&u,&v); 37 addedge(u,v); 38 addedge(v,u); 39 } 40 printf("%lld\n",solve()); 41 } 42 return 0; 43 }
这种方法的复杂度为什么会比FWT少了k呢?这个k哪里去了呢?我想大概是在FWT的过程中把所有集合的dp值都求出来了,而我们只需要求全集的dp值,因此多做了许多无用功。
(ps:由于题目数据的限制,用map优化的点分治可能会更快一些)