LOJ#2320 生成树计数
解:讲一个别的题解里我比较难以理解的地方,就是为什么可以把这两个东西合起来看成某一个连通块指数是2m而别的指数都是m。
其实很好理解,但是别人都略过了......把后面的∑提到∏的前面,然后展开,也可以理解成把∏塞到∑里面。
然后我们就发现对于每个生成树,我们其实有n种选择,分别把某个块的次数变成2m,且每种选择都作为一棵生成树计入贡献,且这回的贡献,一个树内部各个块全部是乘积的形式。
发现贡献与度数有关,又要求所有生成树,于是考虑prufer序列。
如何看待每个点是一个连通块?就是对于一种生成树,实际方案要乘上(点数度数)。代表每条边连哪个点。
这样一个生成树的权值是这个东西:
接下来考虑钦定每个连通块的度数。在prufer序列中每个数的出现次数是度数-1,于是令di = di - 1
首先要把这些点在prufer中排列一下,于是有:
接下来一顿操作,把di!塞到∏里面,(n-2)!提出来,就有个式子。
然后考虑生成函数,推荐这个。
关于求等幂和这个式子,实在是不理解...
这东西还要我用vector写多项式......
然后搞来搞去,不管了...
1 #include <cstdio> 2 #include <algorithm> 3 #include <cstring> 4 #include <cmath> 5 #include <vector> 6 7 typedef long long LL; 8 const int N = 100010; 9 const LL MO = 998244353, G = 3; 10 typedef LL arr[N << 2]; 11 typedef std::vector<LL> Poly; 12 13 arr A, B, exp_t, inv_t, inv_t2, f, g, h, p, ex; 14 int r[N << 2], n; 15 LL m, a[N], pw[N]; 16 17 inline LL qpow(LL a, LL b) { 18 a %= MO; 19 LL ans = 1; 20 while(b) { 21 if(b & 1) ans = ans * a % MO; 22 a = a * a % MO; 23 b = b >> 1; 24 } 25 return ans; 26 } 27 28 inline void prework(int n) { 29 static int R = 0; 30 if(R == n) return; 31 R = n; 32 int lm = 1; 33 while((1 << lm) < n) lm++; 34 for(int i = 0; i < n; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1)); 35 return; 36 } 37 38 inline void NTT(LL *a, int n, int f) { 39 prework(n); 40 for(int i = 0; i < n; i++) { 41 if(i < r[i]) std::swap(a[i], a[r[i]]); 42 } 43 for(int len = 1; len < n; len <<= 1) { 44 LL Wn = qpow(G, (MO - 1) / (len << 1)); 45 if(f == -1) Wn = qpow(Wn, MO - 2); 46 for(int i = 0; i < n; i += (len << 1)) { 47 LL w = 1; 48 for(int j = 0; j < len; j++) { 49 LL t = a[i + len + j] * w % MO; 50 a[i + len + j] = (a[i + j] - t) % MO; 51 a[i + j] = (a[i + j] + t) % MO; 52 w = w * Wn % MO; 53 } 54 } 55 } 56 if(f == -1) { 57 LL inv = qpow(n, MO - 2); 58 for(int i = 0; i < n; i++) a[i] = a[i] * inv % MO; 59 } 60 return; 61 } 62 63 inline void mul(const LL *a, const LL *b, LL *c, int n) { 64 int len = 1; 65 while(len < n + n) len <<= 1; 66 memcpy(A, a, n * sizeof(LL)); memset(A + n, 0, (len - n) * sizeof(LL)); 67 memcpy(B, b, n * sizeof(LL)); memset(B + n, 0, (len - n) * sizeof(LL)); 68 NTT(A, len, 1); NTT(B, len, 1); 69 for(int i = 0; i < len; i++) c[i] = A[i] * B[i] % MO; 70 NTT(c, len, -1); 71 memset(c + n, 0, (len - n) * sizeof(LL)); 72 return; 73 } 74 75 inline Poly mul(const Poly &a, const Poly &b) { 76 int len = 1, lena = a.size(), lenb = b.size(); 77 while(len < lena + lenb) len <<= 1; 78 Poly ans(lena + lenb - 1); 79 for(int i = 0; i < lena; i++) A[i] = a[i]; 80 for(int i = 0; i < lenb; i++) B[i] = b[i]; 81 memset(A + lena, 0, (len - lena) * sizeof(LL)); 82 memset(B + lenb, 0, (len - lenb) * sizeof(LL)); 83 NTT(A, len, 1); NTT(B, len, 1); 84 for(int i = 0; i < len; i++) A[i] = A[i] * B[i] % MO; 85 NTT(A, len, -1); 86 for(int i = 0; i < lena + lenb - 1; i++) ans[i] = A[i]; 87 return ans; 88 } 89 90 void Inv(const LL *a, LL *b, int n) { 91 if(n == 1) { 92 b[0] = qpow(a[0], MO - 2); 93 b[1] = 0; 94 return; 95 } 96 Inv(a, b, n >> 1); 97 /// ans = b * (2 - a * b); 98 memcpy(A, a, n * sizeof(LL)); memset(A + n, 0, n * sizeof(LL)); 99 memcpy(B, b, n * sizeof(LL)); memset(B + n, 0, n * sizeof(LL)); 100 NTT(A, n << 1, 1); NTT(B, n << 1, 1); 101 for(int i = 0; i < (n << 1); i++) b[i] = B[i] * (2 - A[i] * B[i] % MO) % MO; 102 NTT(b, n << 1, -1); 103 memset(b + n, 0, n * sizeof(LL)); 104 return; 105 } 106 107 inline void getInv(const LL *a, LL *b, int n) { 108 int len = 1; 109 while(len < n) len <<= 1; 110 Inv(a, b, len); 111 memset(b + n, 0, (len - n) * sizeof(LL)); 112 return; 113 } 114 115 inline Poly getInv(const Poly &a) { 116 int n = a.size(), len = 1; 117 while(len < n) len <<= 1; 118 for(int i = 0; i < n; i++) inv_t[i] = a[i]; 119 memset(inv_t + n, 0, (len - n) * sizeof(LL)); 120 getInv(inv_t, inv_t2, n); 121 Poly ans(n); 122 for(int i = 0; i < n; i++) ans[i] = inv_t2[i]; 123 return ans; 124 } 125 126 inline void der(const LL *a, LL *b, int n) { 127 for(int i = 0; i < n - 1; i++) { 128 b[i] = a[i + 1] * (i + 1) % MO; 129 } 130 b[n - 1] = 0; 131 return; 132 } 133 134 inline void ter(const LL *a, LL *b, int n) { 135 for(int i = n - 1; i >= 1; i--) { 136 b[i] = a[i - 1] * qpow(i, MO - 2) % MO; 137 } 138 b[0] = 0; 139 return; 140 } 141 142 inline void getLn(const LL *a, LL *b, int n) { 143 getInv(a, inv_t, n); 144 der(a, A, n); 145 int len = 1; 146 while(len < n + n) len <<= 1; 147 memset(A + n, 0, (len - n) * sizeof(LL)); 148 memcpy(B, inv_t, n * sizeof(LL)); memset(B + n, 0, (len - n) * sizeof(LL)); 149 NTT(A, len, 1); NTT(B, len, 1); 150 for(int i = 0; i < len; i++) b[i] = A[i] * B[i] % MO; 151 NTT(b, len, -1); 152 memset(b + n, 0, (len - n) * sizeof(LL)); 153 ter(b, b, n); 154 return; 155 } 156 157 void Exp(const LL *a, LL *b, int n) { 158 if(n == 1) { 159 b[0] = 1; 160 b[1] = 0; 161 return; 162 } 163 Exp(a, b, n >> 1); 164 /// ans = b * (1 + a - ln b) 165 getLn(b, exp_t, n); 166 for(int i = 0; i < n; i++) A[i] = (a[i] - exp_t[i]) % MO; 167 A[0] = (A[0] + 1) % MO; 168 memset(A + n, 0, n * sizeof(LL)); 169 memcpy(B, b, n * sizeof(LL)); memset(B + n, 0, n * sizeof(LL)); 170 NTT(A, n << 1, 1); NTT(B, n << 1, 1); 171 for(int i = 0; i < (n << 1); i++) b[i] = A[i] * B[i] % MO; 172 NTT(b, n << 1, -1); 173 memset(b + n, 0, n * sizeof(LL)); 174 return; 175 } 176 177 inline void getExp(const LL *a, LL *b, int n) { 178 int len = 1; 179 while(len < n) len <<= 1; 180 Exp(a, b, len); 181 memset(b + n, 0, (len - n) * sizeof(LL)); 182 return; 183 } 184 185 inline void out(const Poly &a) { 186 //printf("siz = %d ", a.size()); 187 for(int i = 0; i < a.size(); i++) { 188 printf("%lld ", (a[i] + MO) % MO); 189 } 190 printf("\n"); 191 return; 192 } 193 194 Poly dvd(int l, int r) { 195 if(l == r) { 196 Poly res(2); 197 res[0] = 1; res[1] = -a[r]; 198 return res; 199 } 200 int mid = (l + r) >> 1; 201 Poly ans = mul(dvd(l, mid), dvd(mid + 1, r)); 202 return ans; 203 } 204 205 inline void solve1() { 206 Poly q = dvd(1, n); 207 Poly p(n); 208 for(int i = 0; i < n; i++) { 209 p[i] = q[i] * (n - i) % MO; 210 } 211 p = mul(p, getInv(q)); 212 for(int i = 0; i < n; i++) { 213 ex[i] = p[i]; 214 //printf("ex %d = %lld \n", i, ex[i]); 215 } 216 return; 217 } 218 219 int main() { 220 221 scanf("%d%lld", &n, &m); 222 for(int i = 1; i <= n; i++) { 223 scanf("%lld", &a[i]); 224 } 225 /// 226 solve1(); 227 228 pw[0] = 1; 229 for(int i = 1; i <= n; i++) pw[i] = pw[i - 1] * i % MO; 230 for(int i = 0; i < n; i++) { 231 f[i] = qpow(i + 1, m) * qpow(pw[i], MO - 2) % MO; 232 h[i] = qpow(i + 1, m << 1) * qpow(pw[i], MO - 2) % MO; 233 } 234 getInv(f, p, n); 235 mul(p, h, h, n); 236 237 getLn(f, g, n); 238 for(int i = 0; i < n; i++) { 239 g[i] = g[i] * ex[i] % MO; 240 h[i] = h[i] * ex[i] % MO; 241 } 242 getExp(g, f, n); 243 244 mul(h, f, f, n); 245 246 LL ans = f[n - 2]; 247 for(int i = 1; i < n - 1; i++) ans = ans * i % MO; 248 for(int i = 1; i <= n; i++) ans = ans * a[i] % MO; 249 250 if(ans < 0) ans += MO; 251 printf("%lld\n", ans); 252 253 return 0; 254 }