牛客练习赛71 E.神奇的迷宫 (分治,NTT优化,dfs)

题目:传送门

题意

思路

若能求得取得两点距离为 i 的总概率,那么就可以直接 o(n) 得到答案了。

问题转化为求树上所有距离为 i (i:0~n-1) 的点对的概率和。

考虑用分治,每次,找一颗树的根,使得根的儿子中最深的深度尽可能的小。

然后,遍历根的所有儿子,每次,算出儿子到根的距离,维护概率和,然后将得到的概率和B 与 已经算过的儿子的概率和A,做一次NTT,然后更新答案即可。

算完之后,分治算子数的贡献。

#include <bits/stdc++.h>
#define LL long long
#define ULL unsigned long long
#define UI unsigned int
#define mem(i, j) memset(i, j, sizeof(i))
#define rep(i, j, k) for(int i = j; i <= k; i++)
#define dep(i, j, k) for(int i = k; i >= j; i--)
#define pb push_back
#define make make_pair
#define INF 0x3f3f3f3f
#define inf LLONG_MAX
#define PI acos(-1)
#define fir first
#define sec second
#define lb(x) ((x) & (-(x)))
#define dbg(x) cout<<#x<<" = "<<x<<endl;
using namespace std;

///                   NTT.start              ///
const int N = 1e6 + 5;
const LL G = 3;
const LL mod = 998244353;
int len, r[N];
LL x[N], y[N], w[N];
LL ksm(LL x,LL y) {
    LL ans = 1;
    while(y) {
        if( y & 1 ) ans = ans * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return ans;
}
void ntt(LL *a, LL f) {
    for (LL i = 0; i < len; i++) {
        if (i < r[i]) swap(a[i], a[r[i]]);
    }
    w[0] = 1;
    for (LL i = 2; i <= len; i *= 2) {
        LL wn;
        if (f == 1) wn = ksm(G, (LL)(mod - 1) / i);
        else wn = ksm(G, (LL)(mod - 1) - (mod - 1) / i);
        for (LL j = i / 2; j >= 0; j -= 2) w[j] = w[j / 2];
        for (LL j = 1; j < i / 2; j += 2) w[j] = (w[j - 1] * wn) % mod;
        for (LL j = 0; j < len; j += i) {
            for (LL k = 0 ; k < i / 2; k++) {
                LL u = a[j + k], v = (a[j + k + i / 2] * w[k]) % mod;
                a[j + k] = (u + v) % mod;
                a[j + k + i / 2] = (u - v + mod) % mod;
            }
        }
    }
    if (f == -1) {
        LL inv = ksm(len, mod - 2);
        for (LL i = 0; i < len; i++) a[i] = (a[i] * inv) % mod;
    }
}
void NTT(LL *a, LL *b, LL *c, LL n, LL m) {
    len = 1;
    while (len <= (n + m)) len *= 2;
    int k = trunc(log(len + 0.5) / log(2));
    for (int i = 0; i < len; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
    }
    for (int i = 0; i < len; i++) {
        if (i < n) x[i] = a[i]; else x[i] = 0;
        if (i < m) y[i] = b[i]; else y[i] = 0;
    }
    ntt(x, 1); ntt(y, 1);
    for (LL i = 0; i < len; i++) c[i] = x[i] * y[i] % mod;
    ntt(c, -1);
}

///                   NTT.end              ///

LL a[N], cost[N], ans[N], A[N], B[N], C[N];

vector < int > Q[N];

int rt, n, sz[N], ma[N], depth[N], ma_depth[N];

bool vis[N];

void get_root(int u, int fa, int all) { /// 找树的根,尽可能的让儿子的sz(大小)不要差太多

    sz[u] = 1; ma[u] = 0;

    for(auto v : Q[u]) {

        if(vis[v] || v == fa) continue;

        get_root(v, u, all);

        sz[u] += sz[v];

        ma[u] = max(ma[u], sz[v]);

    }

    ma[u] = max(ma[u], all - ma[u]);

    if(ma[u] < ma[rt]) rt = u;

}

void dfs(int u, int fa) { ///遍历儿子,维护儿子的深度和大小

    sz[u] = 1; ma_depth[u] = depth[u];

    for(auto v : Q[u]) {

        if(vis[v] || v == fa) continue;

        depth[v] = depth[u] + 1;

        dfs(v, u);

        sz[u] += sz[v];

        ma_depth[u] = max(ma_depth[u], ma_depth[v]);

    }

}

void GO(int u, int fa) { /// 更新一下 B数组, B[i] 表示取到深度为 i 的概率

    B[depth[u]] = (B[depth[u]] + a[u]) % mod;

    for(auto v : Q[u]) {

        if(vis[v] || v == fa) continue;

        GO(v, u);

    }

}

void cal(int u) {

    vis[u] = 1;

    int L = 0;

    for(auto v : Q[u]) {

        if(vis[v]) continue;

        depth[v] = 1;

        dfs(v, u);

    }

    A[0] = a[u]; /// A[i] 是当前已经遍历过的子树中,取到深度为 i 的概率

    for(auto v : Q[u]) { /// 遍历儿子

        if(vis[v]) continue;

        GO(v, u); /// 用以当前儿子为根的子树去更新B[i]

        NTT(A, B, C, L + 1, ma_depth[v] + 1); /// NTT 做 A * B = C, 用这个模版,长度要加1

        int Len = L + ma_depth[v] + 2; 

        rep(i, 1, len) ans[i] = (ans[i] + C[i]) % mod; /// 更新答案

        rep(i, 0, ma_depth[v]) A[i] = (A[i] + B[i]) % mod, B[i] = 0LL; /// 将当前 B[i] 更新到 A[i] 去

        rep(i, 0, len) C[i] = 0LL; 

        L = max(L, ma_depth[v]); /// 维护一下A的长度

    }

    rep(i, 0, L) A[i] = 0LL;

    for(auto v : Q[u]) { /// 分治算儿子的贡献

        if(vis[v]) continue;

        rt = 0;

        get_root(v, u, sz[v]);

        cal(rt);

    }

}

void solve() {
    
    ///若跳到同个点上,则只需算一次,否则(a,b)和(b,a)各算一次,需要乘2,这里先求得跳到同个点的,再分治长度大于0的贡献。
    
    scanf("%d", &n);

    LL s = 0LL;

    rep(i, 1, n) scanf("%lld", &a[i]), s += a[i];

    s = ksm(s, mod - 2);

    LL res = 0LL;

    rep(i, 1, n) a[i] = a[i] * s % mod, res = (res + a[i] * a[i] % mod) % mod;

    rep(i, 0, n - 1) scanf("%lld", &cost[i]);

    res = res * cost[0] % mod;

    rep(i, 1, n - 1) {

        int u, v;

        scanf("%d %d", &u, &v);

        Q[u].pb(v); Q[v].pb(u);

    }

    rt = 0; ma[0] = N;

    get_root(1, 0, n);

    cal(rt);

    rep(i, 1, n - 1) res = (res + 2LL * ans[i] * cost[i] % mod) % mod;

    printf("%lld\n", res);

}


int main() {

//    int _; scanf("%d", &_);
//    while(_--) solve();

    solve();

    return 0;
}

 

posted on 2020-10-10 15:43  Willems  阅读(157)  评论(0编辑  收藏  举报

导航