ZJOI2018 线图

妈妈我终于会写这玩意了~为了纪念一下特地在UOJ上把差评改成好评

官方题解已经写得很清楚了,这里稍微补充一下

计算每棵树对应的节点个数时,所谓容斥就是减去它所有(真)连通子图的答案

所谓DP的时候“去掉自同构个数”,就是把方案数除以 $s_1!\cdot s_2!\cdots$,其中 $s_1,s_2,\dots$ 代表各类同构子树的个数

一个很重要的优化是:如果两棵树作为有根树不同构、但作为无根树同构,就不要重复计算答案

#include <cassert>
#include <cstdio>
#include <ctime>

#include <algorithm>
#include <bitset>
#include <iostream>
#include <map>
#include <utility>
#include <vector>
  9 using namespace std;
 10 vector<int> E[5005];
 11 vector<int> kind[1250];
 12 vector<int> T[15];
 13 int rnd[2005];
 14 int dp[5005][1250];
 15 int crt[20] , ntd[1250];
 16 int C[5005][20];
 17 int inv[21];
 18 const int mod = 998244353;
 19 const int inv6 = 166374059;
 20 int n , kk , ans = 0 , limit , cnt = 0;
 21 int p[20] , tot = 0 , e_tot = 0;
 22 map<int,int> mp;
 23 map<long long,int> hashh;
 24 map<pair<int,int> , int> edge;
 25 int seed = 91478513 , a = 16554871 , b = 35659598;
 26 int power(int a,int b)
 27 {
 28     int temp = a , ans = 1;
 29     while(b){
 30         if(b&1) ans = (1LL * ans * temp) % mod;
 31         temp = (1LL * temp * temp) % mod;
 32         b >>= 1;
 33     }
 34     return ans;
 35 }
 36 inline int rand()
 37 {
 38     return seed = (1LL * (seed ^ a) * b) % mod;
 39 }
 40 int get_hash(int fa,int u)
 41 {
 42     long long h = 1;
 43     vector<int> hh;
 44     for(int i = 0;i < T[u].size();i++){
 45         if(T[u][i] != fa){
 46             int g = get_hash(u , T[u][i]);
 47             h += rnd[g];
 48             hh.push_back(g);
 49         }
 50     }
 51     map<long long , int>::iterator it = hashh.find(h);
 52     if(it != hashh.end()) return it->second;
 53     hashh.insert(pair<long long,int>{h , ++tot});
 54     kind[tot] = hh;ntd[tot] = 1;
 55     sort(kind[tot].begin() , kind[tot].end());
 56     int qq = 1 , start = 0;ntd[tot] = 1;
 57     for(;start < kind[tot].size() && kind[tot][start] == 1;start++);
 58     for(int i = start + 1;i < kind[tot].size();i++){
 59         if(kind[tot][i] == 1) continue;
 60         if(kind[tot][i] == kind[tot][i - 1]) qq++;
 61         else{
 62             ntd[tot] = (1LL * ntd[tot] * inv[qq]) % mod;
 63             qq = 1;
 64         }
 65     }
 66     ntd[tot] = (1LL * ntd[tot] * inv[qq]) % mod;
 67     return tot;
 68 }
 69 int get_p(int fa,int u)
 70 {
 71     long long h = 1;
 72     for(int i = 0;i < T[u].size();i++){
 73         if(T[u][i] != fa){
 74             int g = get_p(u , T[u][i]);
 75             if(g == -1) return -1;
 76             h += rnd[g];
 77         }
 78     }
 79     map<long long , int>::iterator it = hashh.find(h);
 80     if(it != hashh.end()) return it->second;
 81     return -1;
 82 }
 83 inline int get_num(int u,int v)
 84 {
 85     if(u > v) swap(u , v);
 86     map<pair<int,int> , int>::iterator it = edge.find(pair<int,int>{u , v});
 87     if(it != edge.end()) return it->second;
 88     e_tot++;
 89     edge.insert(pair<pair<int,int>,int>{pair<int,int>{u , v} , e_tot});
 90     return e_tot;
 91 }
 92 int get_node(int cnode , int k , vector<int> G[])
 93 {
 94     edge.clear();e_tot = 0;
 95     int q = 0;
 96     for(int i = 1;i <= cnode;i++) q += G[i].size();
 97     q /= 2;
 98     if(k == 2){
 99         int ans = 0;
100         for(int i = 1;i <= cnode;i++) ans = (ans + 1LL * (G[i].size() - 1) * G[i].size()) % mod;
101         return (1LL * ans * inv[2]);
102     }
103     if(k == 3){
104         int ans = 0;
105         for(int i = 1;i <= cnode;i++){
106             for(int j = 0;j < G[i].size();j++){
107                 if(i > G[i][j]) continue;
108                 if(G[i].size() + G[G[i][j]].size() < 4) continue;
109                 int d = G[i].size() + G[G[i][j]].size() - 2;
110                 ans = (ans + 1LL * d * (d - 1) % mod * inv[2]) % mod;
111             }
112         }
113         return ans;
114     }
115     if(k == 4){
116         vector<int> ed[cnode + 1];
117         int ans = 0;
118         for(int i = 1;i <= cnode;i++){
119             for(int j = 0;j < G[i].size();j++){
120                 if(i > G[i][j]) continue;
121                 int d0 = G[i].size()+G[G[i][j]].size()-2;
122                 ed[i].push_back (d0 - 1), ed[G[i][j]].push_back (d0 - 1);
123                 if (d0 > 1) ans = (ans + 1LL * d0 * (d0 - 1) % mod * (d0 - 2) % mod) % mod;
124             }
125         }
126         for(int i = 1;i <= cnode;i++){
127             long long x = 0, y = 0;
128             for (int j = 0; j < ed[i].size(); ++ j)    if (ed[i][j]>0)
129                 x = (x + ed[i][j]) % mod, y = (y+1LL * (ed[i][j] * ed[i][j]) % mod) % mod;
130             ans = (ans + (x * x % mod - y + mod) % mod) % mod;
131         }
132         return 1LL * ans * inv[2] % mod;
133     }
134     vector<int> G2[q + 1];
135     for(int i = 1;i <= cnode;i++){
136         if(G[i].size() < 2) continue;
137         int pst[G[i].size()];
138         for(int j = 0;j < G[i].size();j++) pst[j] = get_num(i , G[i][j]);
139         for(int j = 0;j < G[i].size();j++){
140             for(int k = j + 1;k < G[i].size();k++){
141                 G2[pst[j]].push_back(pst[k]);G2[pst[k]].push_back(pst[j]);
142             }
143         }
144     }
145     return get_node(e_tot , k - 1 , G2);
146 }
147 bool connect[(1<<10) + 1];
148 int neigh[20];
149 int brout(vector<int> G[])
150 {
151     for(int i = 0;i < (1<<limit);i++) connect[i] = 0;
152     for(int i = 1;i <= limit;i++){
153         neigh[i] = 0;
154         for(int j = 0;j < G[i].size();j++){
155             neigh[i] |= (1<<G[i][j]-1);
156         }
157         connect[1<<i-1] = 1;
158     }
159     int ans = 0;
160     for(int i = 1;i < (1<<limit) - 1;i++){
161         if(!connect[i]) continue;
162         int mask = 0 , pp = 0;
163         for(int j = 1;j <= limit;j++){
164             if((i >> j-1) & 1) mask |= neigh[j];
165         }
166         for(int j = 1;j <= limit;j++){
167             if((mask >> j - 1) & 1) {connect[i | (1<<j-1)] = 1;}
168         }
169         if(i == (i&-i)) continue;
170         for(int j = 1;j <= limit;j++){
171             T[j].clear();
172             if(((i >> j-1) & 1) == 0) continue;
173             if(!pp) pp = j;
174             for(int k = 0;k < G[j].size();k++){
175                 if((i>>G[j][k]-1) & 1) {T[j].push_back(G[j][k]);}
176             }
177         }
178         ans = (ans + mp[get_hash(0 , pp)]) % mod;
179     }
180     return ans;
181 }
182 void count()
183 {
184     for(int i = 1;i <= limit;i++) T[i].clear();
185     for(int i = 2;i <= limit;i++){
186         T[i].push_back(p[i]);T[p[i]].push_back(i);
187     }
188     int h = get_hash(0 , 1);
189     map<int,int>::iterator it = mp.find(h);
190     if(it != mp.end()) return;
191     for(int i = 2;i <= limit;i++){
192         int g = get_p(0 , i);
193         if(g == -1) continue;
194         map<int,int>::iterator it = mp.find(g);
195         if(it != mp.end()) {mp.insert(pair<int,int>{h , it->second});return;}
196     }
197     vector<int> W[limit + 1];
198     for(int i = 1;i <= limit;i++){
199         for(int j = 0;j < T[i].size();j++) W[i].push_back(T[i][j]);
200     }
201     int t = get_node(limit , kk , T);
202     int g = brout(W);
203     t = (t + mod - g) % mod;
204     mp.insert(pair<int,int>{h , t});
205     return;
206 }
207 void dfs(int x)
208 {
209     if(x == limit){
210         count();return;
211     }
212     for(int i = 1;i <= x;i++){
213         p[x + 1] = i;dfs(x + 1);
214     }
215     return;
216 }
217 void find(int fa,int u)
218 {
219     dp[u][1] = 1;
220     for(int i = 0;i < E[u].size();i++){
221         if(E[u][i] != fa) find(u , E[u][i]);
222     }
223     if(E[u].size() == 1 && fa != 0) return;
224     int siz = (fa == 0) ? E[u].size() : E[u].size() - 1;
225     int cop[siz + 1][(1<<kk)+1];
226     for(int i = 2;i <= cnt;i++){
227         int len = 0 , g = 1 , pcnt = 0;
228         for(int j = 0;j < kind[i].size();j++){
229             if(kind[i][j] != 1){
230                 crt[++pcnt] = j;
231                 ++len;
232             }
233         }
234         for(int j = 0;j <= siz;j++){
235             for(int k = 0;k < (1<<len);k++) cop[j][k] = 0;
236         }
237         cop[0][0] = 1;
238         for(int j = 0;j < E[u].size();j++){
239             if(E[u][j] == fa) continue;
240             cop[g][0] = 1;
241             for(int k = 1;k < (1<<len);k++){
242                 cop[g][k] = cop[g - 1][k];
243                 for(int p = 0;p < len;p++){
244                     if((k>>p) & 1){
245                         cop[g][k] = (cop[g][k] + 1LL * cop[g - 1][k ^ (1<<p)] * dp[E[u][j]][kind[i][crt[p+1]]]) % mod;
246                     }
247                 }
248             }
249             g++;
250         }
251         if(siz >= kind[i].size())  dp[u][i] = (1LL * cop[siz][(1<<len) - 1] * C[siz - len][kind[i].size() - len]) % mod;
252         dp[u][i] = (1LL * dp[u][i] * ntd[i]) % mod;
253     }
254     return;
255 }
256 int main()
257 {
258     scanf("%d%d",&n,&kk);
259     C[0][0] = 1;
260     for(int i = 1;i <= n;i++){
261         C[i][0] = 1;
262         for(int j = 1;j <= i && j <= 15;j++){
263             C[i][j] = (C[i-1][j] + C[i-1][j - 1]) % mod;
264         }
265     }
266     int g = 1;inv[1] = inv[0] = 1;
267     for(int i = 2;i <= 20;i++){
268         g = (1LL * g * i) % mod;
269         inv[i] = power(g , mod - 2);
270     }
271     for(int i = 0;i <= 2000;i++) rnd[i] = rand();
272     for(int i = 1;i < n;i++){
273         int u , v;scanf("%d%d",&u,&v);
274         E[u].push_back(v);
275         E[v].push_back(u);
276     }
277     mp.insert(pair<int,int>{1 , 0});
278     for(int i = 2;i <= kk + 1;i++){
279         limit = i;
280         dfs(1);
281     }cnt = tot;
282     find(0 , 1);
283     int ans = 0;
284     for(int i = 1;i <= n;i++){
285         for(int j = 1;j <= cnt;j++){
286             ans = (ans + 1LL * dp[i][j] * mp[j]) % mod;
287         }
288     }
289     printf("%d\n",ans);
290     return 0;
291 }
View Code

 

posted @ 2018-05-31 13:09  Sugar!  阅读(667)  评论(0)  编辑  收藏  举报