牛客练习赛71 E.神奇的迷宫 (分治,NTT优化,dfs)
题目:传送门
题意
思路
若能求得取得两点距离为 i 的总概率,那么就可以直接 o(n) 得到答案了。
问题转化为求树上所有距离为 i (i:0~n-1) 的点对的概率和。
考虑用分治,每次,找一颗树的根,使得根的儿子中最深的深度尽可能的小。
然后,遍历根的所有儿子,每次,算出儿子到根的距离,维护概率和,然后将得到的概率和B 与 已经算过的儿子的概率和A,做一次NTT,然后更新答案即可。
算完之后,分治算子数的贡献。
#include <bits/stdc++.h> #define LL long long #define ULL unsigned long long #define UI unsigned int #define mem(i, j) memset(i, j, sizeof(i)) #define rep(i, j, k) for(int i = j; i <= k; i++) #define dep(i, j, k) for(int i = k; i >= j; i--) #define pb push_back #define make make_pair #define INF 0x3f3f3f3f #define inf LLONG_MAX #define PI acos(-1) #define fir first #define sec second #define lb(x) ((x) & (-(x))) #define dbg(x) cout<<#x<<" = "<<x<<endl; using namespace std; /// NTT.start /// const int N = 1e6 + 5; const LL G = 3; const LL mod = 998244353; int len, r[N]; LL x[N], y[N], w[N]; LL ksm(LL x,LL y) { LL ans = 1; while(y) { if( y & 1 ) ans = ans * x % mod; x = x * x % mod; y >>= 1; } return ans; } void ntt(LL *a, LL f) { for (LL i = 0; i < len; i++) { if (i < r[i]) swap(a[i], a[r[i]]); } w[0] = 1; for (LL i = 2; i <= len; i *= 2) { LL wn; if (f == 1) wn = ksm(G, (LL)(mod - 1) / i); else wn = ksm(G, (LL)(mod - 1) - (mod - 1) / i); for (LL j = i / 2; j >= 0; j -= 2) w[j] = w[j / 2]; for (LL j = 1; j < i / 2; j += 2) w[j] = (w[j - 1] * wn) % mod; for (LL j = 0; j < len; j += i) { for (LL k = 0 ; k < i / 2; k++) { LL u = a[j + k], v = (a[j + k + i / 2] * w[k]) % mod; a[j + k] = (u + v) % mod; a[j + k + i / 2] = (u - v + mod) % mod; } } } if (f == -1) { LL inv = ksm(len, mod - 2); for (LL i = 0; i < len; i++) a[i] = (a[i] * inv) % mod; } } void NTT(LL *a, LL *b, LL *c, LL n, LL m) { len = 1; while (len <= (n + m)) len *= 2; int k = trunc(log(len + 0.5) / log(2)); for (int i = 0; i < len; i++) { r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1)); } for (int i = 0; i < len; i++) { if (i < n) x[i] = a[i]; else x[i] = 0; if (i < m) y[i] = b[i]; else y[i] = 0; } ntt(x, 1); ntt(y, 1); for (LL i = 0; i < len; i++) c[i] = x[i] * y[i] % mod; ntt(c, -1); } /// NTT.end /// LL a[N], cost[N], ans[N], A[N], B[N], C[N]; vector < int > Q[N]; int rt, n, sz[N], ma[N], depth[N], ma_depth[N]; bool vis[N]; void get_root(int u, int fa, int all) { /// 找树的根,尽可能的让儿子的sz(大小)不要差太多 sz[u] = 1; ma[u] = 0; for(auto v : Q[u]) { if(vis[v] || v == fa) continue; get_root(v, u, all); sz[u] += sz[v]; ma[u] = max(ma[u], sz[v]); } ma[u] = max(ma[u], all - ma[u]); if(ma[u] < ma[rt]) rt = u; } void dfs(int u, int fa) { ///遍历儿子,维护儿子的深度和大小 sz[u] = 1; ma_depth[u] = depth[u]; for(auto v : Q[u]) { if(vis[v] || v == fa) continue; depth[v] = depth[u] + 1; dfs(v, u); sz[u] += sz[v]; ma_depth[u] = max(ma_depth[u], ma_depth[v]); } } void GO(int u, int fa) { /// 更新一下 B数组, B[i] 表示取到深度为 i 的概率 B[depth[u]] = (B[depth[u]] + a[u]) % mod; for(auto v : Q[u]) { if(vis[v] || v == fa) continue; GO(v, u); } } void cal(int u) { vis[u] = 1; int L = 0; for(auto v : Q[u]) { if(vis[v]) continue; depth[v] = 1; dfs(v, u); } A[0] = a[u]; /// A[i] 是当前已经遍历过的子树中,取到深度为 i 的概率 for(auto v : Q[u]) { /// 遍历儿子 if(vis[v]) continue; GO(v, u); /// 用以当前儿子为根的子树去更新B[i] NTT(A, B, C, L + 1, ma_depth[v] + 1); /// NTT 做 A * B = C, 用这个模版,长度要加1 int Len = L + ma_depth[v] + 2; rep(i, 1, len) ans[i] = (ans[i] + C[i]) % mod; /// 更新答案 rep(i, 0, ma_depth[v]) A[i] = (A[i] + B[i]) % mod, B[i] = 0LL; /// 将当前 B[i] 更新到 A[i] 去 rep(i, 0, len) C[i] = 0LL; L = max(L, ma_depth[v]); /// 维护一下A的长度 } rep(i, 0, L) A[i] = 0LL; for(auto v : Q[u]) { /// 分治算儿子的贡献 if(vis[v]) continue; rt = 0; get_root(v, u, sz[v]); cal(rt); } } void solve() { ///若跳到同个点上,则只需算一次,否则(a,b)和(b,a)各算一次,需要乘2,这里先求得跳到同个点的,再分治长度大于0的贡献。 scanf("%d", &n); LL s = 0LL; rep(i, 1, n) scanf("%lld", &a[i]), s += a[i]; s = ksm(s, mod - 2); LL res = 0LL; rep(i, 1, n) a[i] = a[i] * s % mod, res = (res + a[i] * a[i] % mod) % mod; rep(i, 0, n - 1) scanf("%lld", &cost[i]); res = res * cost[0] % mod; rep(i, 1, n - 1) { int u, v; scanf("%d %d", &u, &v); Q[u].pb(v); Q[v].pb(u); } rt = 0; ma[0] = N; get_root(1, 0, n); cal(rt); rep(i, 1, n - 1) res = (res + 2LL * ans[i] * cost[i] % mod) % mod; printf("%lld\n", res); } int main() { // int _; scanf("%d", &_); // while(_--) solve(); solve(); return 0; }
一步一步,永不停息