牛客练习赛71 E- 神奇的迷宫 点分治+NTT
牛客练习赛71 E- 神奇的迷宫
题意
给一颗\(n\)个点的树,每条边的长度均为\(1\),Alice和Bob两人依次传送到树的某两个结点。对于任意一个人,传送到点\(i\)的概率为\(p_i\),假设两人传送到的结点之间的最短距离为\(L\),那么他们挑战这个树的困难度为\(w_i\)。
问他们挑战这个树的困难度的期望是多少。
\(n\le 10^5\)
分析
令\(ans[i]\)表示两人最短距离为\(i\)的概率,答案即为\(\sum_{i=0}^{n-1}ans[i]\cdot w[i]\)。
求\(ans[i]\)可以用点分治来做,以\(u\)作为分治中心时,枚举每个子树,用\(A[i]\)表示已经枚举过的子树中到根的距离为\(i\)的点的概率之和,用\(B[i]\)表示当前子树中到根的距离为\(i\)的点的概率之和,那么就可以更新\(ans[k]+=\sum_{i=0}^{k}A[i]\cdot B[k-i]\),注意到这是一个卷积形式,所以我们对\(A,B\)做一次卷积就能更新\(ans[i]\),因为答案要取模,所以用\(NTT\)来做卷积。
复杂度为\(O(nlog^2n)\)。
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e6+10;
const int mod = 998244353, G = 3, Gi = 332748118;
int n;
ll p[N],w[N];
vector<int>g[N];
int sz[N],vis[N],mx[N],rt,tot,k1,k2;
ll ans,A[N],B[N];
int limit = 1, L, r[N];
ll a[N], b[N];
ll ksm(ll a, ll b) {
ll ret = 1;
while(b) {
if(b & 1) ret = (ret * a ) % mod;
a = (a * a) % mod;
b >>= 1;
}
return ret;
}
void NTT(ll *A, int type) {
for(int i = 0; i < limit; i++)
if(i < r[i]) swap(A[i], A[r[i]]);
for(int mid = 1; mid < limit; mid <<= 1) {
ll Wn = ksm( type == 1 ? G : Gi , (mod - 1) / (mid << 1));
for(int j = 0; j < limit; j += (mid << 1)) {
ll w = 1;
for(int k = 0; k < mid; k++, w = (w * Wn) % mod) {
int x = A[j + k], y = w * A[j + k + mid] % mod;
A[j + k] = (x + y) % mod,
A[j + k + mid] = (x - y + mod) % mod;
}
}
}
}
void gao() {
limit=1,L=0;
for(int i=0;i<=k1;i++) a[i]=A[i];
for(int i=0;i<=k2;i++) b[i]=B[i];
while(limit <= k1 + k2) limit <<= 1, L++;
for(int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
NTT(a, 1);NTT(b, 1);
for(int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % mod;
NTT(a, -1);
ll inv = ksm(limit, mod - 2);
for(int i=0;i<=k1+k2;i++) a[i]=(a[i]*inv)%mod;
for(int i = 0; i <= k1 + k2&&i<n; i++){
ans=(ans+a[i] * w[i]%mod*2%mod)%mod;
}
for(int i=0;i<=limit;i++) a[i]=b[i]=r[i]=0;
}
void getrt(int u,int fa){
sz[u]=1,mx[u]=0;
for(int x:g[u]){
if(x==fa||vis[x]) continue;
getrt(x,u);
sz[u]+=sz[x];
mx[u]=max(mx[u],sz[x]);
}
mx[u]=max(mx[u],tot-sz[u]);
if(mx[u]<mx[rt]) rt=u;
}
void dfs(int u,int fa,int d){
B[d]=(B[d]+p[u])%mod;
k2=max(k2,d);
for(int x:g[u]){
if(x==fa||vis[x]) continue;
dfs(x,u,d+1);
}
}
void solve(int u){
vis[u]=1;k1=k2=0;
A[0]=p[u];
for(int x:g[u]){
if(vis[x]) continue;
k2=0;
dfs(x,u,1);
/*
for(int i=0;i<=k1;i++){
for(int j=0;j<=k2;j++) if(i+j<n){
ans+=w[i+j]*A[i]%mod*B[j]%mod*2%mod;
ans%=mod;
}
}
*/
gao();
k1=max(k1,k2);
for(int i=0;i<=k2;i++) A[i]=(A[i]+B[i])%mod,B[i]=0;
}
for(int i=0;i<=k1;i++) A[i]=0;
for(int x:g[u]){
if(vis[x]) continue;
tot=sz[x],mx[rt=0]=n;
getrt(x,0);
solve(rt);
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%lld",&p[i]);
p[0]+=p[i];
if(p[0]>=mod) p[0]-=mod;
}
p[0]=ksm(p[0],mod-2);
for(int i=1;i<=n;i++){
scanf("%lld",&w[i-1]);
p[i]=p[i]*p[0]%mod;
}
for(int i=2,x,y;i<=n;i++){
scanf("%d%d",&x,&y);
g[x].push_back(y);
g[y].push_back(x);
}
tot=mx[rt]=n;
getrt(1,0);
solve(rt);
for(int i=1;i<=n;i++) ans=(ans+w[0]*p[i]%mod*p[i]%mod)%mod;
printf("%lld\n",ans);
return 0;
}