牛客练习赛71 E- 神奇的迷宫 点分治+NTT

牛客练习赛71 E- 神奇的迷宫

题意

给一颗\(n\)个点的树,每条边的长度均为\(1\),Alice和Bob两人依次传送到树的某两个结点。对于任意一个人,传送到点\(i\)的概率为\(p_i\),假设两人传送到的结点之间的最短距离为\(L\),那么他们挑战这个树的困难度为\(w_i\)

问他们挑战这个树的困难度的期望是多少。

\(n\le 10^5\)

分析

\(ans[i]\)表示两人最短距离为\(i\)的概率,答案即为\(\sum_{i=0}^{n-1}ans[i]\cdot w[i]\)

\(ans[i]\)可以用点分治来做,以\(u\)作为分治中心时,枚举每个子树,用\(A[i]\)表示已经枚举过的子树中到根的距离为\(i\)的点的概率之和,用\(B[i]\)表示当前子树中到根的距离为\(i\)的点的概率之和,那么就可以更新\(ans[k]+=\sum_{i=0}^{k}A[i]\cdot B[k-i]\),注意到这是一个卷积形式,所以我们对\(A,B\)做一次卷积就能更新\(ans[i]\),因为答案要取模,所以用\(NTT\)来做卷积。

复杂度为\(O(nlog^2n)\)

Code

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int N=1e6+10;
const int mod = 998244353, G = 3, Gi = 332748118;
int n;
ll p[N],w[N];
vector<int>g[N];
int sz[N],vis[N],mx[N],rt,tot,k1,k2;
ll ans,A[N],B[N];
int limit = 1, L, r[N];
ll a[N], b[N];
ll ksm(ll a, ll b) {
    ll ret = 1;
    while(b) {
        if(b & 1) ret = (ret * a ) % mod;
        a = (a * a) % mod;
        b >>= 1;
    }
    return ret;
}
void NTT(ll *A, int type) {
    for(int i = 0; i < limit; i++)
        if(i < r[i]) swap(A[i], A[r[i]]);
    for(int mid = 1; mid < limit; mid <<= 1) {
        ll Wn = ksm( type == 1 ? G : Gi , (mod - 1) / (mid << 1));
        for(int j = 0; j < limit; j += (mid << 1)) {
            ll w = 1;
            for(int k = 0; k < mid; k++, w = (w * Wn) % mod) {
                 int x = A[j + k], y = w * A[j + k + mid] % mod;
                 A[j + k] = (x + y) % mod,
                 A[j + k + mid] = (x - y + mod) % mod;
            }
        }
    }
}
void gao() {
    limit=1,L=0;
    for(int i=0;i<=k1;i++) a[i]=A[i];
    for(int i=0;i<=k2;i++) b[i]=B[i];
    while(limit <= k1 + k2) limit <<= 1, L++;
    for(int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    NTT(a, 1);NTT(b, 1);
    for(int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % mod;
    NTT(a, -1);
    ll inv = ksm(limit, mod - 2);
    for(int i=0;i<=k1+k2;i++) a[i]=(a[i]*inv)%mod;
    for(int i = 0; i <= k1 + k2&&i<n; i++){
        ans=(ans+a[i] * w[i]%mod*2%mod)%mod;
    }
    for(int i=0;i<=limit;i++) a[i]=b[i]=r[i]=0;
}
void getrt(int u,int fa){
    sz[u]=1,mx[u]=0;
    for(int x:g[u]){
        if(x==fa||vis[x]) continue;
        getrt(x,u);
        sz[u]+=sz[x];
        mx[u]=max(mx[u],sz[x]);
    }
    mx[u]=max(mx[u],tot-sz[u]);
    if(mx[u]<mx[rt]) rt=u;
}
void dfs(int u,int fa,int d){
    B[d]=(B[d]+p[u])%mod;
    k2=max(k2,d);
    for(int x:g[u]){
        if(x==fa||vis[x]) continue;
        dfs(x,u,d+1);
    }
}
void solve(int u){
    vis[u]=1;k1=k2=0;
    A[0]=p[u];
    for(int x:g[u]){
        if(vis[x]) continue;
        k2=0;
        dfs(x,u,1);
        /*
        for(int i=0;i<=k1;i++){
            for(int j=0;j<=k2;j++) if(i+j<n){
                ans+=w[i+j]*A[i]%mod*B[j]%mod*2%mod;
                ans%=mod;
            }
        }
        */
        gao();
        k1=max(k1,k2);
        for(int i=0;i<=k2;i++) A[i]=(A[i]+B[i])%mod,B[i]=0;
    }
    for(int i=0;i<=k1;i++) A[i]=0;
    for(int x:g[u]){
        if(vis[x]) continue;
        tot=sz[x],mx[rt=0]=n;
        getrt(x,0);
        solve(rt);
    }
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%lld",&p[i]);
        p[0]+=p[i];
        if(p[0]>=mod) p[0]-=mod;
    }
    p[0]=ksm(p[0],mod-2);
    for(int i=1;i<=n;i++){
        scanf("%lld",&w[i-1]);
        p[i]=p[i]*p[0]%mod;
    }
    for(int i=2,x,y;i<=n;i++){
        scanf("%d%d",&x,&y);
        g[x].push_back(y);
        g[y].push_back(x);
    }
    tot=mx[rt]=n;
    getrt(1,0);
    solve(rt);
    for(int i=1;i<=n;i++) ans=(ans+w[0]*p[i]%mod*p[i]%mod)%mod;
    printf("%lld\n",ans);
    return 0;
}
posted @ 2020-10-13 16:21  xyq0220  阅读(97)  评论(0编辑  收藏  举报