ZJOI2018 线图
妈妈我终于会写这玩意了~为了纪念一下特地在UOJ上把差评改成好评
官方题解已经写得很清楚了,这里稍微补充一下
计算每棵树对应的节点个数时，所谓容斥，就是减去它所有真连通子图（至少两个点、且不是整棵树本身）的答案。
所谓 DP 时要去掉自同构带来的重复计数，就是除以 $s_1!\,s_2!\cdots$，其中 $s_1, s_2, \ldots$ 分别是各类互相同构的（非叶）子树的个数。
一个很重要的优化是：如果两棵树作为有根树不同构、但作为无根树同构，就不要重复计算答案，直接复用已经算出的结果。
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
#include<cstdio>
#include<map>
#include<vector>
#include<cassert>
#include<bitset>
#include<ctime>
#include<iostream>
#include<algorithm>
using namespace std;

// ZJOI2018 "line graph" (see the notes above the code).
// Input: n, kk, then the n - 1 edges of a tree. Output: one number mod 998244353.
//
// Outline:
//  1. dfs()/count() enumerate every tree with 2..kk+1 vertices. For each rooted shape,
//     mp[shape] = |V(L^kk(shape))| minus the mp values of all its proper connected
//     subtrees (inclusion-exclusion).
//  2. find() computes dp[u][i]: how many connected vertex sets have u as their top-most
//     vertex and rooted shape i.
//  3. The program prints sum over u, i of dp[u][i] * mp[i]. By the inclusion-exclusion
//     this is the number of vertices of the kk-th iterated line graph of the input tree.
//
// Capacity: shape tables (kind, ntd, dp's second index) hold ids 1..1249 and connect[]
// covers trees with at most 10 vertices, so kk <= 9 is required.

vector<int> E[5005];          // adjacency lists of the input tree (vertices 1..n)
vector<int> kind[1250];       // kind[id]: sorted ids of the child subtrees of rooted shape id
vector<int> T[15];            // adjacency lists of the small tree currently examined
int rnd[2005];                // random weight per shape id, used by the subtree hash
int dp[5005][1250];           // dp[u][id], see find()
int crt[20] , ntd[1250];      // crt: scratch for find(); ntd[id]: prod over repeated non-leaf children of 1/multiplicity!
int C[5005][20];              // binomial coefficients C[i][j], filled for j <= 15
int inv[21];                  // inv[i] = (i!)^{-1} mod p
const int mod = 998244353;
const int inv6 = 166374059;   // 6^{-1} mod p (currently unused)
int n , kk , ans = 0 , limit , cnt = 0;  // limit: vertex count of the shapes being enumerated; cnt: number of shape ids
int p[20] , tot = 0 , e_tot = 0;        // p: parent array built by dfs(); tot: shape ids issued; e_tot: edge ids issued
map<int,int> mp;                         // shape id -> inclusion-exclusion value (id 1 = single vertex -> 0)
map<long long,int> hashh;                // subtree hash -> shape id
map<pair<int,int> , int> edge;           // (u, v) with u <= v -> edge id, reset in every get_node() call
int seed = 91478513 , a = 16554871 , b = 35659598;  // state and parameters of next_rand()

// a^b mod p by binary exponentiation.
int power(int a,int b)
{
    int temp = a , ans = 1;
    while(b){
        if(b & 1) ans = (1LL * ans * temp) % mod;
        temp = (1LL * temp * temp) % mod;
        b >>= 1;
    }
    return ans;
}

// Deterministic pseudo-random generator used to seed rnd[].
// Named next_rand (not rand) so it cannot collide with ::rand, which the standard
// headers may pull in from <cstdlib> with a different exception specification.
inline int next_rand()
{
    return seed = (1LL * (seed ^ a) * b) % mod;
}

// Returns the shape id of the subtree of T rooted at u (fa = its parent, 0 for the root).
// The hash is 1 + sum of rnd[child id], so it depends only on the multiset of child
// shapes (collisions are possible but unlikely). Unseen shapes get a new id, with
// kind[id] = sorted child ids and ntd[id] = product of inv[s] over runs of s identical
// non-leaf child ids. Id 1 is the single vertex: it is the first shape ever hashed.
int get_hash(int fa,int u)
{
    long long h = 1;
    vector<int> hh;
    for(int i = 0;i < (int)T[u].size();i++){
        if(T[u][i] != fa){
            int g = get_hash(u , T[u][i]);
            h += rnd[g];
            hh.push_back(g);
        }
    }
    map<long long , int>::iterator it = hashh.find(h);
    if(it != hashh.end()) return it->second;
    assert(tot + 1 < 1250);  // fixed-size shape tables (kind, ntd, dp)
    ++tot;
    hashh.insert(pair<long long,int>{h , tot});
    kind[tot] = hh;
    sort(kind[tot].begin() , kind[tot].end());
    ntd[tot] = 1;
    // Leaf children (id 1) sort first; they are handled by a binomial in find(),
    // so only runs of equal non-leaf ids contribute a 1/s! factor here.
    int qq = 1 , start = 0;
    for(;start < (int)kind[tot].size() && kind[tot][start] == 1;start++);
    for(int i = start + 1;i < (int)kind[tot].size();i++){
        if(kind[tot][i] == 1) continue;
        if(kind[tot][i] == kind[tot][i - 1]) qq++;
        else{
            ntd[tot] = (1LL * ntd[tot] * inv[qq]) % mod;
            qq = 1;
        }
    }
    ntd[tot] = (1LL * ntd[tot] * inv[qq]) % mod;
    return tot;
}

// Like get_hash() but never creates ids: returns the shape id of the subtree of T
// rooted at u if every shape in it is already known, otherwise -1.
int get_p(int fa,int u)
{
    long long h = 1;
    for(int i = 0;i < (int)T[u].size();i++){
        if(T[u][i] != fa){
            int g = get_p(u , T[u][i]);
            if(g == -1) return -1;
            h += rnd[g];
        }
    }
    map<long long , int>::iterator it = hashh.find(h);
    if(it != hashh.end()) return it->second;
    return -1;
}

// Id of the undirected edge {u, v}; new edges get ids 1, 2, ... in order of first request.
inline int get_num(int u,int v)
{
    if(u > v) swap(u , v);
    map<pair<int,int> , int>::iterator it = edge.find(pair<int,int>{u , v});
    if(it != edge.end()) return it->second;
    e_tot++;
    edge.insert(pair<pair<int,int>,int>{pair<int,int>{u , v} , e_tot});
    return e_tot;
}

// Returns |V(L^k(G))| mod p, where G is the graph on vertices 1..cnode given by the
// adjacency lists G[1..cnode] and L^k is the k-th iterated line graph.
// k = 0..4 use closed forms; larger k builds L(G) explicitly and recurses with k - 1.
// Resets the global edge map and e_tot.
int get_node(int cnode , int k , vector<int> G[])
{
    edge.clear();
    e_tot = 0;
    int q = 0;  // number of edges of G
    for(int i = 1;i <= cnode;i++) q += G[i].size();
    q /= 2;
    // Base cases: L^0(G) = G and V(L(G)) = E(G). Without these, k <= 1 would recurse forever.
    if(k == 0) return cnode % mod;
    if(k == 1) return q % mod;
    if(k == 2){
        // |V(L^2)| = |E(L)| = sum over vertices of C(deg, 2).
        int ans = 0;
        for(int i = 1;i <= cnode;i++){
            long long d = G[i].size();
            ans = (ans + d * (d - 1)) % mod;
        }
        return 1LL * ans * inv[2] % mod;  // reduce: the product exceeds int range
    }
    if(k == 3){
        // |V(L^3)| = sum over edges uv of C(deg u + deg v - 2, 2).
        int ans = 0;
        for(int i = 1;i <= cnode;i++){
            for(int j = 0;j < (int)G[i].size();j++){
                int v = G[i][j];
                if(i > v) continue;
                if(G[i].size() + G[v].size() < 4) continue;
                int d = G[i].size() + G[v].size() - 2;
                ans = (ans + 1LL * d * (d - 1) % mod * inv[2]) % mod;
            }
        }
        return ans;
    }
    if(k == 4){
        // Closed form in terms of each edge's line-graph degree d0 = deg u + deg v - 2.
        vector<vector<int>> ed(cnode + 1);
        int ans = 0;
        for(int i = 1;i <= cnode;i++){
            for(int j = 0;j < (int)G[i].size();j++){
                int v = G[i][j];
                if(i > v) continue;
                int d0 = G[i].size() + G[v].size() - 2;
                ed[i].push_back(d0 - 1);
                ed[v].push_back(d0 - 1);
                if(d0 > 1) ans = (ans + 1LL * d0 * (d0 - 1) % mod * (d0 - 2) % mod) % mod;
            }
        }
        for(int i = 1;i <= cnode;i++){
            long long x = 0, y = 0;
            for(int j = 0;j < (int)ed[i].size();j++){
                if(ed[i][j] <= 0) continue;
                x = (x + ed[i][j]) % mod;
                y = (y + 1LL * ed[i][j] * ed[i][j] % mod) % mod;
            }
            ans = (ans + (x * x % mod - y + mod) % mod) % mod;
        }
        return 1LL * ans * inv[2] % mod;
    }
    // Build L(G): its vertices are G's edges, and two are adjacent when they share an endpoint.
    // Edges whose endpoints both have degree < 2 are never numbered; such an edge is
    // an isolated vertex of L(G) and contributes nothing to L^{k-1} here (k - 1 >= 4).
    vector<vector<int>> G2(q + 1);
    for(int i = 1;i <= cnode;i++){
        if(G[i].size() < 2) continue;
        vector<int> pst(G[i].size());
        for(int j = 0;j < (int)G[i].size();j++) pst[j] = get_num(i , G[i][j]);
        for(int j = 0;j < (int)pst.size();j++){
            for(int l = j + 1;l < (int)pst.size();l++){
                G2[pst[j]].push_back(pst[l]);
                G2[pst[l]].push_back(pst[j]);
            }
        }
    }
    return get_node(e_tot , k - 1 , G2.data());
}

bool connect[(1<<10) + 1];  // connect[mask]: the vertex subset mask is connected
int neigh[20];              // neigh[v]: bitmask of v's neighbours

// Sum of mp[shape] over every connected vertex subset S of the tree G (vertices 1..limit)
// with 2 <= |S| < limit, where shape is S rooted at its smallest vertex.
// Connected subsets are found by growing masks one neighbour at a time in increasing
// mask order. Overwrites the global T with the induced subtrees.
int brout(vector<int> G[])
{
    for(int i = 0;i < (1<<limit);i++) connect[i] = 0;
    for(int i = 1;i <= limit;i++){
        neigh[i] = 0;
        for(int j = 0;j < (int)G[i].size();j++) neigh[i] |= (1 << (G[i][j] - 1));
        connect[1 << (i - 1)] = 1;
    }
    int ans = 0;
    for(int i = 1;i < (1<<limit) - 1;i++){  // excludes the full vertex set
        if(!connect[i]) continue;
        int mask = 0 , pp = 0;
        for(int j = 1;j <= limit;j++){
            if((i >> (j - 1)) & 1) mask |= neigh[j];
        }
        for(int j = 1;j <= limit;j++){
            if((mask >> (j - 1)) & 1) connect[i | (1 << (j - 1))] = 1;
        }
        if(i == (i & -i)) continue;  // single vertex: value 0
        for(int j = 1;j <= limit;j++){
            T[j].clear();
            if(((i >> (j - 1)) & 1) == 0) continue;
            if(!pp) pp = j;
            for(int k = 0;k < (int)G[j].size();k++){
                if((i >> (G[j][k] - 1)) & 1) T[j].push_back(G[j][k]);
            }
        }
        ans = (ans + mp[get_hash(0 , pp)]) % mod;
    }
    return ans;
}

// Builds in T the tree given by p[2..limit] (p[i] is the parent of i), rooted at 1.
// If its shape has no value yet, sets mp[shape] = |V(L^kk(T))| - brout(T).
// If re-rooting the same unrooted tree at another vertex gives a shape that already
// has a value, that value is reused instead (see the notes).
void count()
{
    for(int i = 1;i <= limit;i++) T[i].clear();
    for(int i = 2;i <= limit;i++){
        T[i].push_back(p[i]);
        T[p[i]].push_back(i);
    }
    int h = get_hash(0 , 1);
    if(mp.find(h) != mp.end()) return;
    for(int i = 2;i <= limit;i++){
        int g = get_p(0 , i);
        if(g == -1) continue;
        map<int,int>::iterator it = mp.find(g);
        if(it != mp.end()){
            mp.insert(pair<int,int>{h , it->second});
            return;
        }
    }
    vector<vector<int>> W(T, T + limit + 1);  // brout() overwrites T, so pass it a copy
    int t = get_node(limit , kk , T);
    int g = brout(W.data());
    t = (t + mod - g) % mod;
    mp.insert(pair<int,int>{h , t});
}

// Enumerates every parent array with p[x + 1] in 1..x (trees on 1..limit rooted at 1,
// each parent labelled smaller than its child) and calls count() on each complete one.
void dfs(int x)
{
    if(x == limit){
        count();
        return;
    }
    for(int i = 1;i <= x;i++){
        p[x + 1] = i;
        dfs(x + 1);
    }
}

// Tree DP on the input tree (fa = parent of u, 0 for the root).
// dp[u][1] = 1; for shape id i >= 2, dp[u][i] is the number of connected vertex sets
// inside u's subtree that contain u and whose shape rooted at u is i.
// Non-leaf child shapes are assigned to distinct children by a subset DP over child
// positions, then multiplied by ntd[i] to undo orderings of identical shapes. Leaf
// positions are chosen among the remaining children by a binomial coefficient.
void find(int fa,int u)
{
    dp[u][1] = 1;
    for(int i = 0;i < (int)E[u].size();i++){
        if(E[u][i] != fa) find(u , E[u][i]);
    }
    if(E[u].size() == 1 && fa != 0) return;  // non-root leaf: only the single-vertex shape
    int siz = (fa == 0) ? E[u].size() : E[u].size() - 1;  // number of children
    // The subset DP only reads the previous row, so two heap rows replace the former
    // (siz + 1) x (2^kk + 1) stack VLA, which could reach ~10 MB for a high-degree vertex.
    vector<int> row_prev((1 << kk) + 1), row_cur((1 << kk) + 1);
    for(int i = 2;i <= cnt;i++){
        int len = 0;  // number of non-leaf child positions of shape i
        for(int j = 0;j < (int)kind[i].size();j++){
            if(kind[i][j] != 1) crt[++len] = j;
        }
        int full = 1 << len;
        fill(row_prev.begin() , row_prev.begin() + full , 0);
        row_prev[0] = 1;
        for(int j = 0;j < (int)E[u].size();j++){
            int v = E[u][j];
            if(v == fa) continue;
            // row_cur[mask]: ways to fill the positions in mask using distinct children seen so far.
            row_cur[0] = 1;
            for(int m = 1;m < full;m++){
                long long val = row_prev[m];
                for(int bit = 0;bit < len;bit++){
                    if((m >> bit) & 1){
                        val = (val + 1LL * row_prev[m ^ (1 << bit)] * dp[v][kind[i][crt[bit + 1]]]) % mod;
                    }
                }
                row_cur[m] = val;
            }
            swap(row_prev , row_cur);
        }
        if(siz >= (int)kind[i].size())
            dp[u][i] = (1LL * row_prev[full - 1] * C[siz - len][kind[i].size() - len]) % mod;
        dp[u][i] = (1LL * dp[u][i] * ntd[i]) % mod;
    }
}

int main()
{
    scanf("%d%d",&n,&kk);
    assert(kk <= 9);  // capacity of connect[] and the shape tables
    C[0][0] = 1;
    for(int i = 1;i <= n;i++){
        C[i][0] = 1;
        for(int j = 1;j <= i && j <= 15;j++){
            C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % mod;
        }
    }
    int g = 1;
    inv[1] = inv[0] = 1;
    for(int i = 2;i <= 20;i++){
        g = (1LL * g * i) % mod;  // g = i!
        inv[i] = power(g , mod - 2);
    }
    for(int i = 0;i <= 2000;i++) rnd[i] = next_rand();
    for(int i = 1;i < n;i++){
        int u , v;
        scanf("%d%d",&u,&v);
        E[u].push_back(v);
        E[v].push_back(u);
    }
    mp.insert(pair<int,int>{1 , 0});  // single vertex
    for(int i = 2;i <= kk + 1;i++){
        limit = i;
        dfs(1);
    }
    cnt = tot;
    find(0 , 1);
    // Look each shape's value up once instead of n * cnt map lookups.
    vector<int> weight(cnt + 1 , 0);
    for(int j = 1;j <= cnt;j++){
        map<int,int>::iterator it = mp.find(j);
        if(it != mp.end()) weight[j] = it->second;
    }
    int ans = 0;
    for(int i = 1;i <= n;i++){
        for(int j = 1;j <= cnt;j++){
            ans = (ans + 1LL * dp[i][j] * weight[j]) % mod;
        }
    }
    printf("%d\n",ans);
    return 0;
}