洛谷P5206 数树
题意:
task0,给定两棵树T1,T2,取它们公共边(两端点相同)加入一张新的图,记新图连通块个数为x,求yx。
task1,给定T1,求所有T2的task0之和。
task2,求所有T1的task1之和。
解:y = 1的时候特殊处理,就是总方案数。
task0,显然按照题意模拟即可。
task1,对某个T2,设有k条边相同,那么连通块数就是n - k。要求的就是
对于每个T2,前面yn都是一样的,所以直接去掉,最后乘上即可。关注后面这个东西怎么求。令y' = 1/y,E是公共边集。
注意到
这里下式是枚举边集E的子集S,对每个的子集贡献求和。
注意上式先枚举a再求组合数,相当于枚举在边集里选a条边,然后枚举选哪a条边。也就是枚举子集。
也就是下面这段话想表达的。摘自Cyhlnj。里面还提到了一个n3的矩阵树定理做法,神奇。
容斥写法在下不会T_T
下一步,把S提前枚举,在不同的E中同一个S的贡献总是相同的。考虑一个S会对哪些E产生贡献,也就是它的贡献会被计算多少次。
这|S|条边会形成若干个连通块。这些连通块通过加上一些边可以形成树。这些新边没有任何限制,于是就是连通块的生成树计数。
这里又有若干种推法......个人认为最简单的是利用prufer序列求解。
摘自Joker_69。
令z = y' - 1,m = 边集为S时的连通块数 = n - |S|,第i号连通块有ai个点,于是我们的答案变成了这样:
这个东西怎么求呢?注意到在T1中选择任意的边集S等价于把T1划分为若干个连通块,用这些边连起来。于是就考虑树形DP。
这后面这个求积,要乘上每个连通块的大小,有个暴力是f[x][i]表示x为根,x所在连通块大小为i的所有方案权值和。
n2过不了,于是换个思路就是在每个连通块中选一个关键点的方案数。
因为是以联通块为依据DP,所以变形一下,加上之前忽略的yn,我们有:
于是状态设计有f[x][0/1]表示在以x为根的子树中,x所在连通块是否选择了关键点的所有方案权值之和。
每个联通块的贡献是z-1n,且我们只在关键点被选出来的那一瞬间计算这个联通块的贡献。
同时由于每个连通块的贡献要乘起来,那么所有方案之和还是要乘起来,等价于每个方案两两求积再相加。
口胡了半天还是写一下方程吧。
f[x][0] = 1; f[x][1] = invz * n % MO; LL t0 = f[x][0] * f[y][1] % MO + f[x][0] * f[y][0] % MO; LL t1 = f[x][0] * f[y][1] % MO + f[x][1] * f[y][0] % MO + f[x][1] * f[y][1] % MO; f[x][0] = t0 % MO; f[x][1] = t1 % MO;
task2:问题变得严重起来......
跟task1一样,对于某个T1和T2的组合,它的贡献仍能拆成它的子集的贡献。
设g(E)为给定E这个边集之后的生成树个数,由task1可得g(E) = nm-2∏ai。
枚举E为一定相同的边集,剩下的边随便连。那么对于T1的g(E)种情况,T2都有g(E)种情况。
所以E这个边集的贡献为z|E|g2(E)。
m还是连通块数,我们暴力展开g(E),并把与m有关的项放到∏里面,无关的提到外面,令r = n2/z,那么答案就是:
接下来这一步很毒瘤...我们考虑这个式子有什么实际意义。
前面的边集把这个图分成了若干个森林。每个连通块一定是树。后面相当于给每个连通块赋了ai2r的权值,并把权值乘起来作为这个边集的贡献。
设fi为大小为i的树的贡献,对应的EGF是F(x),gi为大小为i的图的贡献,对应的EGF是G(x)
那么有这样一个式子:G(x) = eF(x)
考虑fi是多少:每种树的权值都是i2r,一共有ii-2种树,贡献加起来是iir。
这样就对F(x)做exp,然后拿G(x)的第n项出来搞一搞就是答案了。
多项式操作别写错了......我一开始WA了20分是因为有这样的一句话:n * n
然后两个n乘起来爆int了......这题神奇的一批...
1 #include <cstdio> 2 #include <algorithm> 3 #include <cstring> 4 5 typedef long long LL; 6 const int N = 100010; 7 const LL MO = 998244353; 8 9 inline LL qpow(LL a, LL b) { 10 LL ans = 1; 11 a %= MO; 12 while(b) { 13 if(b & 1) ans = ans * a % MO; 14 a = a * a % MO; 15 b = b >> 1; 16 } 17 return ans; 18 } 19 20 struct Edge { 21 int nex, v; 22 }edge[N << 1]; int tp; 23 24 int n, e[N]; 25 LL Y, z; 26 27 inline void add(int x, int y) { 28 tp++; 29 edge[tp].v = y; 30 edge[tp].nex = e[x]; 31 e[x] = tp; 32 return; 33 } 34 35 namespace t0 { 36 int fa[N]; 37 void DFS(int x, int f) { 38 fa[x] = f; 39 for(int i = e[x]; i; i = edge[i].nex) { 40 int y = edge[i].v; 41 if(y == f) continue; 42 DFS(y, x); 43 } 44 return; 45 } 46 inline void solve() { 47 if(Y == 1) { 48 puts("1"); 49 return; 50 } 51 for(int i = 1, x, y; i < n; i++) { 52 scanf("%d%d", &x, &y); 53 add(x, y); 54 add(y, x); 55 } 56 DFS(1, 0); 57 int k = 0; 58 for(int i = 1, x, y; i < n; i++) { 59 scanf("%d%d", &x, &y); 60 if(fa[x] == y || fa[y] == x) { 61 k++; 62 } 63 } 64 LL ans = qpow(Y, n - k); 65 printf("%lld\n", ans); 66 return; 67 } 68 } 69 70 namespace t1 { 71 LL f[N][2], invz; 72 void DFS(int x, int father) { 73 f[x][0] = 1; 74 f[x][1] = invz * n % MO; 75 //printf("x = %d fa = %d \n", x, father); 76 for(int i = e[x]; i; i = edge[i].nex) { 77 int y = edge[i].v; 78 //printf("y = %d \n", y); 79 if(y == father) continue; 80 DFS(y, x); 81 LL t0 = f[x][0] * f[y][1] % MO + f[x][0] * f[y][0] % MO; 82 LL t1 = f[x][0] * f[y][1] % MO + f[x][1] * f[y][0] % MO + f[x][1] * f[y][1] % MO; 83 f[x][0] = t0 % MO; 84 f[x][1] = t1 % MO; 85 } 86 //printf("X = %d f[x][0] = %lld f[x][1] = %lld \n", x, f[x][0], f[x][1]); 87 return; 88 } 89 inline void solve() { 90 if(Y == 1) { 91 LL ans = qpow(n, n - 2); 92 printf("%lld\n", ans); 93 return; 94 } 95 z = qpow(Y, MO - 2); z = (z - 1 + MO) % MO; 96 invz = qpow(z, MO - 2); 97 for(int i = 1, x, y; i < n; i++) { 98 scanf("%d%d", &x, &y); 99 add(x, y); add(y, x); 100 } 101 DFS(1, 0); 102 LL ans = f[1][1] * qpow(n, MO - 3) % MO * qpow(z, n) % MO * qpow(Y, n) % MO; 103 printf("%lld\n", ans); 104 return; 105 } 106 } 107 108 namespace t2 { 109 110 typedef LL arr[N * 4]; 111 const LL G = 3; 112 113 int r[N * 4]; 114 arr A, B, a, b, inv_t, exp_t, ln_t, ln_t2; 115 LL pw[N]; 116 117 inline void prework(int n) { 118 static int R = 0; 119 if(R == n) return; 120 R = n; 121 int lm = 1; 122 while((1 << lm) < n) lm++; 123 for(int i = 0; i < n; i++) { 124 r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1)); 125 } 126 return; 127 } 128 129 inline void NTT(LL *a, int n, int f) { 130 prework(n); 131 for(int i = 0; i < n; i++) { 132 if(i < r[i]) std::swap(a[i], a[r[i]]); 133 } 134 for(int len = 1; len < n; len <<= 1) { 135 LL Wn = qpow(G, (MO - 1) / (len << 1)); 136 if(f == -1) Wn = qpow(Wn, MO - 2); 137 for(int i = 0; i < n; i += (len << 1)) { 138 LL w = 1; 139 for(int j = 0; j < len; j++) { 140 LL t = a[i + len + j] * w % MO; 141 a[i + len + j] = (a[i + j] - t) % MO; 142 a[i + j] = (a[i + j] + t) % MO; 143 w = w * Wn % MO; 144 } 145 } 146 } 147 if(f == -1) { 148 LL inv = qpow(n, MO - 2); 149 for(int i = 0; i < n; i++) { 150 a[i] = a[i] * inv % MO; 151 } 152 } 153 return; 154 } 155 156 void Inv(const LL *a, LL *b, int n) { 157 if(n == 1) { 158 b[0] = qpow(a[0], MO - 2); 159 b[1] = 0; 160 return; 161 } 162 Inv(a, b, n >> 1); 163 /// ans = b[i] * (2 - a[i] * b[i]) 164 memcpy(A, a, n * sizeof(LL)); memset(A + n, 0, n * sizeof(LL)); 165 memcpy(B, b, n * sizeof(LL)); memset(B + n, 0, n * sizeof(LL)); 166 NTT(A, n << 1, 1); NTT(B, n << 1, 1); 167 for(int i = 0; i < (n << 1); i++) b[i] = B[i] * (2 - A[i] * B[i] % MO) % MO; 168 NTT(b, n << 1, -1); 169 memset(b + n, 0, n * sizeof(LL)); 170 return; 171 } 172 173 inline void getInv(const LL *a, LL *b, int n) { 174 int len = 1; 175 while(len < n) len <<= 1; 176 memcpy(inv_t, a, n * sizeof(LL)); memset(inv_t + n, 0, (len - n) * sizeof(LL)); 177 Inv(inv_t, b, len); 178 memset(b + n, 0, (len - n) * sizeof(LL)); 179 return; 180 } 181 182 inline void der(const LL *a, LL *b, int n) { 183 for(int i = 0; i < n - 1; i++) { 184 b[i] = a[i + 1] * (i + 1) % MO; 185 } 186 b[n - 1] = 0; 187 return; 188 } 189 190 inline void ter(const LL *a, LL *b, int n) { 191 for(int i = n - 1; i >= 1; i--) { 192 b[i] = a[i - 1] * qpow(i, MO - 2) % MO; 193 } 194 b[0] = 0; 195 return; 196 } 197 198 inline void getLn(const LL *a, LL *b, int n) { 199 getInv(a, ln_t, n); 200 der(a, ln_t2, n); 201 int len = 1; 202 while(len < 2 * n) len <<= 1; 203 memset(ln_t + n, 0, (len - n) * sizeof(LL)); 204 memset(ln_t2 + n, 0, (len - n) * sizeof(LL)); 205 NTT(ln_t, len, 1); NTT(ln_t2, len, 1); 206 for(int i = 0; i < len; i++) b[i] = ln_t[i] * ln_t2[i] % MO; 207 NTT(b, len, -1); 208 memset(b + n, 0, (len - n) * sizeof(LL)); 209 ter(b, b, n); 210 return; 211 } 212 213 void Exp(const LL *a, LL *b, int n) { 214 if(n == 1) { 215 b[0] = 1; 216 b[1] = 0; 217 return; 218 } 219 Exp(a, b, n >> 1); 220 /// ans = b * (1 + a - ln b) 221 getLn(b, exp_t, n); 222 for(int i = 0; i < n; i++) A[i] = (a[i] - exp_t[i]) % MO; 223 A[0] = (A[0] + 1) % MO; 224 memset(A + n, 0, n * sizeof(LL)); 225 memcpy(B, b, n * sizeof(LL)); memset(B + n, 0, n * sizeof(LL)); 226 NTT(A, n << 1, 1); NTT(B, n << 1, 1); 227 for(int i = 0; i < (n << 1); i++) b[i] = A[i] * B[i] % MO; 228 NTT(b, n << 1, -1); 229 memset(b + n, 0, n * sizeof(LL)); 230 return; 231 } 232 233 inline void getExp(const LL *a, LL *b, int n) { 234 int len = 1; 235 while(len < n) len <<= 1; 236 Exp(a, b, len); 237 memset(b + n, 0, (len - n) * sizeof(LL)); 238 return; 239 } 240 241 inline void solve() { 242 if(Y == 1) { 243 LL t = qpow(n, n - 2); 244 printf("%lld\n", t * t % MO); 245 return; 246 } 247 248 LL z = (qpow(Y, MO - 2) - 1) % MO; 249 LL r = 1ll * n * n % MO * qpow(z, MO - 2) % MO; 250 251 pw[0] = 1; 252 for(int i = 1; i <= n; i++) { 253 pw[i] = pw[i - 1] * i % MO; 254 a[i] = qpow(i, i) * r % MO * qpow(pw[i], MO - 2) % MO; 255 } 256 getExp(a, b, n + 1); 257 LL ans = b[n] * pw[n] % MO; 258 ans = ans * qpow(Y, n) % MO * qpow(z, n) % MO * qpow(n, MO - 5) % MO; 259 printf("%lld\n", (ans + MO) % MO); 260 return; 261 } 262 } 263 264 int main() { 265 266 int f; 267 scanf("%d%lld%d", &n, &Y, &f); 268 if(f == 0) { 269 t0::solve(); 270 return 0; 271 } 272 if(f == 1) { 273 t1::solve(); 274 return 0; 275 } 276 t2::solve(); 277 return 0; 278 }
以蒟蒻视角写了题解,以后还要继续努力!
感谢: