HDU - 5977 Garden of Eden (树形dp+容斥)

题意:一棵树上有n(n<=50000)个结点,结点有k(k<=10)种颜色,问树上总共有多少条包含所有颜色的路径。

我最初的想法是树形状压dp,设dp[u][S]为以结点u为根的包含颜色集合为S的路径条数,然后FWT(应该叫FMT?)搞一下就行了,复杂度$O(nk2^k)$。奈何内存太大,妥妥地MLE...

看到网上大部分的解法都是点分治,我不禁联想到之前学过的树上任意两点距离的求法(点分治+FFT),心想,这道题用点分治+FWT是不是也能过?于是比着葫芦画瓢写出了这样一段又臭又长的代码:

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 typedef double db;
 5 const int N=5e4+10,inf=0x3f3f3f3f;
 6 int n,k,a[N],hd[N],ne,vis[N],K,siz[N],tot,rt,mx;
 7 ll dp[1<<10],ans;
 8 struct E {int v,nxt;} e[N<<1];
 9 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
10 void FWT(ll* a,int n,int f) {
11     for(int k=1; k<n; k<<=1)
12         for(int i=0; i<n; i+=k<<1)
13             for(int j=i; j<i+k; ++j)
14                 a[j+k]+=f==1?a[j]:-a[j];
15 }
16 void getroot(int u,int fa) {
17     siz[u]=1;
18     int sz=0;
19     for(int i=hd[u]; ~i; i=e[i].nxt) {
20         int v=e[i].v;
21         if(vis[v]||v==fa)continue;
22         getroot(v,u);
23         siz[u]+=siz[v];
24         sz=max(sz,siz[v]);
25     }
26     sz=max(sz,tot-siz[u]);
27     if(sz<mx)mx=sz,rt=u;
28 }
29 void dfs(int u,int fa,int S) {
30     dp[S]++;
31     for(int i=hd[u]; ~i; i=e[i].nxt) {
32         int v=e[i].v;
33         if(vis[v]||v==fa)continue;
34         dfs(v,u,S|a[v]);
35     }
36 }
37 void cal(int u,int ba,int f) {
38     for(int i=0; i<=K; ++i)dp[i]=0;
39     dfs(u,-1,a[u]|ba);
40     FWT(dp,K+1,1);
41     for(int i=0; i<=K; ++i)dp[i]*=dp[i];
42     FWT(dp,K+1,-1);
43     ans+=dp[K]*f;
44 }
45 void solve(int u) {
46     mx=inf,getroot(u,-1),u=rt,cal(u,0,1),vis[u]=1;
47     for(int i=hd[u]; ~i; i=e[i].nxt) {
48         int v=e[i].v;
49         if(!vis[v])tot=siz[v],cal(v,a[u],-1),solve(v);
50     }
51 }
52 ll treepartion() {
53     ans=0,tot=n;
54     solve(1);
55     return ans;
56 }
57 int main() {
58     while(scanf("%d%d",&n,&k)==2) {
59         memset(hd,-1,sizeof hd),ne=0;
60         memset(vis,0,sizeof vis);
61         K=(1<<k)-1;
62         for(int i=1; i<=n; ++i)scanf("%d",&a[i]),a[i]=1<<(a[i]-1);
63         for(int i=1; i<n; ++i) {
64             int u,v;
65             scanf("%d%d",&u,&v);
66             addedge(u,v);
67             addedge(v,u);
68         }
69         printf("%lld\n",treepartion());
70     }
71     return 0;
72 }
View Code

虽然成功地AC了,但是仔细一想:不对啊,这道题FWT的复杂度和子树的大小不是线性相关的啊!所以这样一来,总的复杂度成了$O(nk2^klogn)$,反而增大了。

也就是说,这道题用点分治的作用仅仅是减少了内存的开销,复杂度非但没有减少,反而还多了个logn!

当然除了点分治,这道题还有其他的优化方法,比如sclbgw7大佬利用树链剖分的思想将内存优化到了$O(2^klogn)$,时间复杂度仍为$O(nk2^k)$。

树剖做法:

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int N=5e4+10,inf=0x3f3f3f3f;
 5 int n,k,a[N],hd[N],ne,S,fa[N],son[N],siz[N],tot;
 6 ll dp[17][1<<10],b[1<<10],A[N],B[N],ans;
 7 struct E {int v,nxt;} e[N<<1];
 8 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
 9 void FWT(ll* a,int n,int f) {
10     for(int k=1; k<n; k<<=1)
11         for(int i=0; i<n; i+=k<<1)
12             for(int j=i; j<i+k; ++j) {
13                 ll x=a[j],y=a[j+k];
14                 a[j+k]=f==1?y+x:y-x;
15             }
16 }
17 void mul(ll* a,ll* b,ll* c,int n) {
18     for(int i=0; i<n; ++i)A[i]=a[i],B[i]=b[i];
19     FWT(A,n,1),FWT(B,n,1);
20     for(int i=0; i<n; ++i)c[i]=A[i]*B[i];
21     FWT(c,n,-1);
22 }
23 int newnode() {
24     int u=tot++;
25     for(int i=0; i<(1<<k); ++i)dp[u][i]=0;
26     return u;
27 }
28 void dfs1(int u,int f) {
29     fa[u]=f,son[u]=0,siz[u]=1;
30     for(int i=hd[u]; ~i; i=e[i].nxt) {
31         int v=e[i].v;
32         if(v==fa[u])continue;
33         dfs1(v,u),siz[u]+=siz[v];
34         if(siz[v]>siz[son[u]])son[u]=v;
35     }
36 }
37 void dfs2(int u,int w) {
38     if(son[u])dfs2(son[u],w);
39     for(int i=0; i<(1<<k); ++i)b[i]=0;
40     b[1<<a[u]]=1;
41     if((1<<a[u]==S))ans++;
42     for(int i=0; i<(1<<k); ++i)if((i|(1<<a[u]))==S)ans+=dp[w][i]*2;
43     for(int i=0; i<(1<<k); ++i)b[i|(1<<a[u])]+=dp[w][i];
44     for(int i=0; i<(1<<k); ++i)dp[w][i]=b[i];
45     for(int i=hd[u]; ~i; i=e[i].nxt) {
46         int v=e[i].v;
47         if(v==fa[u]||v==son[u])continue;
48         int wv=newnode();
49         dfs2(v,wv),tot--;
50         mul(dp[w],dp[wv],b,1<<k);
51         ans+=b[S]*2;
52         for(int i=0; i<(1<<k); ++i)dp[w][i|(1<<a[u])]+=dp[wv][i];
53     }
54 }
55 int main() {
56     while(scanf("%d%d",&n,&k)==2) {
57         memset(hd,-1,sizeof hd),ne=0;
58         S=(1<<k)-1;
59         for(int i=1; i<=n; ++i)scanf("%d",&a[i]),a[i]--;
60         for(int i=1; i<n; ++i) {
61             int u,v;
62             scanf("%d%d",&u,&v);
63             addedge(u,v);
64             addedge(v,u);
65         }
66         tot=ans=0;
67         dfs1(1,-1),dfs2(1,newnode());
68         printf("%lld\n",ans);
69     }
70     return 0;
71 }
View Code

还有Menhera大佬利用基的FMT性质将时间复杂度优化到$O(n2^k)$的做法,看样子有点像是容斥。我个人更倾向于这一种,于是在这个思想的基础上进一步地分析:

题目要求的是包含所有颜色的路径条数。如果包含某个元素集合的路径不太好求,那么不包含某个元素集合的呢?只要把属于这个集合的结点都染成白色,不属于的都染成黑色,则问题就转化成了求一棵树上包含的所有点都是黑色的路径条数,直接dp求一下就行了。于是我们可以利用容斥原理,用所有的路径数减去不包含1个元素集合的路径数,再加上不包含2个元素集合的路径数,再减去不包含3个...就得到了答案。

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int N=5e4+10,inf=0x3f3f3f3f;
 5 int n,k,a[N],siz[N],hd[N],ne,ppc[1<<10];
 6 ll ans;
 7 struct E {int v,nxt;} e[N<<1];
 8 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
 9 void dfs(int u,int fa,int f) {
10     for(int i=hd[u]; ~i; i=e[i].nxt)if(e[i].v!=fa)dfs(e[i].v,u,f);
11     if(!siz[u])return;
12     ans+=f;
13     for(int i=hd[u]; ~i; i=e[i].nxt)if(e[i].v!=fa) {
14             int v=e[i].v;
15             ans+=(ll)siz[v]*siz[u]*2*f;
16             siz[u]+=siz[v];
17         }
18 }
19 ll solve() {
20     ans=0;
21     for(int S=(1<<k)-1; S; --S) {
22         int f=(k-ppc[S])&1?-1:1;
23         for(int i=1; i<=n; ++i)siz[i]=S>>a[i]&1;
24         dfs(1,-1,f);
25     }
26     return ans;
27 }
28 int main() {
29     ppc[0]=0;
30     for(int i=1; i<(1<<10); ++i)ppc[i]=ppc[i>>1]+(i&1);
31     while(scanf("%d%d",&n,&k)==2) {
32         memset(hd,-1,sizeof hd),ne=0;
33         for(int i=1; i<=n; ++i)scanf("%d",&a[i]),a[i]--;
34         for(int i=1; i<n; ++i) {
35             int u,v;
36             scanf("%d%d",&u,&v);
37             addedge(u,v);
38             addedge(v,u);
39         }
40         printf("%lld\n",solve());
41     }
42     return 0;
43 }
View Code

这种方法的复杂度为什么会比FWT少了k呢?这个k哪里去了呢?我想大概是在FWT的过程中把所有集合的dp值都求出来了,而我们只需要求全集的dp值,因此多做了许多无用功。

 (ps:由于题目数据的限制,用map优化的点分治可能会更快一些)

posted @ 2019-03-19 12:55  jrltx  阅读(378)  评论(0编辑  收藏  举报