hdu 5909 Tree Cutting FWT

显然的树形DP：$dp[u][i]$ 表示在以 $u$ 为根的子树中，包含 $u$ 的连通块（即以 $u$ 为最高点的连通子树）里，点权异或和为 $i$ 的连通块数量。

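转移也很显然。初始时 $dp[u][val_u] = 1$（只含 $u$ 自身的连通块；$val_u$ 即代码中的 `v[u]`）。每合并一个儿子 $c$，有 $dp'[u][i] = dp[u][i] + \sum_{j \oplus k = i} dp[u][j]\cdot dp[c][k]$：要么不取儿子这一侧，要么把一个包含 $c$ 的连通块接到 $u$ 上。后一项是异或卷积，使用 FWT 处理即可。最终答案为 $ans[i] = \sum_{u} dp[u][i]$（每个连通块只在其最高点处被计数一次）。复杂度 $O(N \cdot M \log M)$。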

 1 #include <cstdio>
 2 using namespace std;
 // 64-bit type used for intermediate products before reducing modulo `mo`.
typedef long long ll;

// Capacity of node-indexed arrays (nodes are numbered from 1).
const int MAXN = 1100;
// Capacity of edge arrays: each undirected edge is stored twice (once per
// direction) and edge ids start at 1.
const int MAXM = 2100;
// Prime modulus that every count is reduced by.
const int mo = 1000000007;

// Forward-star adjacency list: head[u] is the first edge id leaving u,
// nxt[e] the next edge id leaving the same node, to[e] the edge's endpoint.
int head[MAXN];
// Node values.
int v[MAXN];
int to[MAXM];
int nxt[MAXM];
// Per-test-case answer, indexed by value.
int res[1024];
// dp[u][s]: DP table filled by dfs(); the second dimension bounds m
// (TODO confirm m <= 1024 holds for all inputs).
int dp[MAXN][1024];

// n: node count, m: size of the value range, T: number of test cases,
// cnt: id of the last edge added, inv2: modular inverse of 2 (set in main).
int n,m,T,cnt,inv2;
 7 int fpow(int x,int k)
 8 {
 9     if (k == 1)
10         return x;
11     int t = fpow(x,k >> 1);
12     if (k & 1)
13         return (ll)t * t % mo * x % mo;
14     return (ll)t * t % mo;
15 }
16 void add(int x,int y)
17 {
18     nxt[++cnt] = head[x];
19     to[cnt] = y;
20     head[x] = cnt;
21 }
22 void fwt(int *a,int len,bool inv)
23 {
24     for(int d = 1;d < len;d <<= 1)
25     {
26         for(int h = d << 1,i = 0;i <= len - 1;i += h)
27             for(int j = 0;j < d;j++)
28             {
29                 int x = a[i + j],y = a[i + j + d];
30                 a[i + j] = (x + y) % mo;
31                 a[i + j + d]=(x - y + mo) % mo;
32             }
33         if (inv == true)    
34             for (int i = 0;i <= len - 1;i++)
35                 a[i] = (ll)a[i] * inv2 % mo;
36     }
37 }
38 void dfs(int x,int frm)
39 {
40     int tmp1[1024],tmp2[1024];
41     dp[x][v[x]] = 1;
42     for (int i = head[x];i;i = nxt[i])
43     {
44         if (to[i] == frm)
45             continue;
46         dfs(to[i],x);
47         for (int o = 0;o <= m - 1;o++)
48         {
49             tmp1[o] = dp[x][o]; 
50             tmp2[o] = dp[to[i]][o];
51         }
52         fwt(tmp1,m,false);
53         fwt(tmp2,m,false);
54         for (int o = 0;o <= m - 1;o++)
55             tmp1[o] = (ll)tmp1[o] * tmp2[o] % mo;
56         fwt(tmp1,m,true);
57         for (int o = 0;o <= m - 1;o++)
58             dp[x][o] = (dp[x][o] + tmp1[o]) % mo;
59     }
60 }
61 int main()
62 {
63     inv2 = fpow(2,mo - 2);
64     for (scanf("%d",&T);T != 0;T--)
65     {
66         cnt = 0;
67         scanf("%d%d",&n,&m);
68         for (int i = 1;i <= n;i++)
69             scanf("%d",&v[i]);
70         int tx,ty;
71         for (int i = 1;i <= n - 1;i++)
72         {
73             scanf("%d%d",&tx,&ty);
74             add(tx,ty);
75             add(ty,tx); 
76         }
77         dfs(1,0);
78         for (int i = 1;i <= n;i++)
79             for (int j = 0;j <= m - 1;j++)
80                 res[j] = (res[j] + dp[i][j]) % mo;
81         for (int i = 0;i <= m - 2;i++)
82             printf("%d ",res[i]);
83         printf("%d\n",res[m - 1]); 
84         for (int i = 1;i <= cnt;i++)
85             nxt[i] = 0;
86         for (int i = 1;i <= n;i++)
87             head[i] = 0;
88         cnt = 0;
89         for (int i = 1;i <= n;i++)
90             for (int j = 0;j <= m - 1;j++)
91                 dp[i][j] = 0;
92         for (int i = 0;i <= m - 1;i++)
93             res[i] = 0;
94     }
95     return 0;
96 }

 

posted @ 2019-08-16 11:04  IAT14  阅读(414)  评论(0)  编辑  收藏  举报