【纪中集训2019.3.12】Mas的仙人掌
题意:
给出一棵\(n\)个点的树,需要加\(m\)条边,每条边脱落的概率为\(p_{i}\) ,求加入的边在最后形成图中仅在一个简单环上的边数的期望;
\(1 \le n \ , m \le 10^6\)
题解:
-
考虑每一条边的贡献是\((1-p_{i})*\Pi_{j}p_{j}(j!=i)\),这里\(j\)和\(i\)不能同时加入;
-
一条加入的边可以看成一条树上路径 ,即求所有和路径\(i\)相交的路径\(j\)的\(p_{j}\)的乘积;
-
将一条树上的链\((u,v)\)拆成两条\((u,lca)\)和\((v,lca)\);
-
这样会算重复\(i\)和\(j\)的两条链都相交的情况,但是这样\((u,lca)\)上\(lca\)的儿子是唯一的,\(v\)同理,\(hash\)去掉多算的部分;
-
考虑一条路径\((u,v) dep[u]<dep[v]\),考虑一个点的子树里向上有长度为\(2\)的链的个数,简单容斥预处理出在\(v\)的子树里但不在\((u,v)\)上的点对答案的贡献;
-
在\((u,v)\)上的点直接每条链在最下端打标记即可;
-
时间复杂度:O(\(nlog \ n\))
#include<bits/stdc++.h> #define ll long long #define mod 998244353 using namespace std; const int N=1000010; int n,m,o=1,hd[N],a[N],b[N],ia[N],ans[N]; int fa[N][21],bin[21],dep[N],cnt,tot; int s1[N],s2[N],s3[N]; ll s4[N]; struct Edge{int v,nt;}E[N<<1]; struct data{ int u,v,w,id; }A[N],B[N<<1]; int pw(int x,int y){ int re=1; while(y){ if(y&1)re=(ll)re*x%mod; y>>=1;x=(ll)x*x%mod; } return re; }// void adde(int u,int v){ E[o]=(Edge){v,hd[u]};hd[u]=o++; E[o]=(Edge){u,hd[v]};hd[v]=o++; }// char gc(){ static char*p1,*p2,s[1000000]; if(p1==p2)p2=(p1=s)+fread(s,1,1000000,stdin); return(p1==p2)?EOF:*p1++; }// int rd(){ int x=0;char c=gc(); while(c<'0'||c>'9')c=gc(); while(c>='0'&&c<='9')x=(x<<1)+(x<<3)+c-'0',c=gc(); return x; }// const int sz=1234651; struct HASH{ int o,U[N],V[N],w[N],hd[sz],nt[N]; int ask(int u,int v){ if(u>v)swap(u,v); int x = ((ll)u*mod+v)%sz; for(int i=hd[x];i;i=nt[i]){ if(U[i]==u&&V[i]==v)return w[i]; } return 1; } void upd(int u,int v,int y){ if(u>v)swap(u,v); int x = ((ll)u*mod+v)%sz; for(int i=hd[x];i;i=nt[i]){ if(U[i]==u&&V[i]==v){w[i]=(ll)w[i]*y%mod;return;} } nt[++o]=hd[x];hd[x]=o;U[o]=u,V[o]=v;w[o]=y; } }H;// void dfs1(int u,int F){ dep[u]=dep[fa[u][0]=F]+1; for(int i=1;bin[i]<dep[u];++i)fa[u][i]=fa[fa[u][i-1]][i-1]; for(int i=hd[u];i;i=E[i].nt){ int v=E[i].v; if(v==F)continue; dfs1(v,u); } }// int go(int u,int d){ for(int i=0;i<20&&d;++i)if(d&bin[i])d^=bin[i],u=fa[u][i]; return u; }// int lca(int u,int v){ if(dep[u]<dep[v])swap(u,v); u=go(u,dep[u]-dep[v]); if(u==v)return u; for(int i=19;~i;--i)if(fa[u][i]!=fa[v][i])u=fa[u][i],v=fa[v][i]; return fa[u][0]; }// void dfs2(int u,int F){ for(int i=hd[u];i;i=E[i].nt){ int v=E[i].v; if(v==F)continue; dfs2(v,u); s2[u]=(ll)s2[u]*s2[v]%mod; s3[u]=(ll)s3[u]*s2[v]%mod; s4[u]=(s4[u]+s4[v]); } }// void dfs3(int u,int F){ for(int i=hd[u];i;i=E[i].nt){ int v=E[i].v; if(v==F)continue; s1[v]=(ll)s1[u]*s1[v]%mod; s2[v]=(ll)s2[u]*s2[v]%mod; s3[v]=(ll)s3[u]*s3[v]%mod; s4[v]=(s4[u]+s4[v]); dfs3(v,u); } }// inline int cal1(int u,int v){return (ll)s1[u]*pw(s1[v],mod-2)%mod;} inline int cal2(int u,int v){return (ll)s2[u]*pw(s2[v],mod-2)%mod;} inline int cal3(int u,int v){return (ll)s3[u]*pw(s3[v],mod-2)%mod;} inline ll cal4(int u,int v){return s4[u]-s4[v];} // int main(){ freopen("cactus.in","r",stdin); freopen("cactus.out","w",stdout); n=rd();m=rd(); for(int i=bin[0]=1;i<=20;++i)bin[i]=bin[i-1]<<1; for(int i=1;i<n;++i)adde(rd(),rd()); dfs1(1,0); for(int i=1;i<=max(n,m);++i){ans[i]=s1[i]=s2[i]=s3[i]=1;} for(int i=1,u,v,w,p,q;i<=m;++i){ u=rd();v=rd();w=lca(u,v); a[i]=rd();ia[i]=pw(a[i],mod-2); b[i]=(1+mod-a[i])%mod; if(!a[i]){ s4[u]++;s4[v]++;s4[w]-=2; if(u!=w){ p=go(u,dep[u]-dep[w]-1); B[++cnt]=(data){u,w,p,i}; } if(v!=w){ q=go(v,dep[v]-dep[w]-1); B[++cnt]=(data){v,w,q,i}; } if(u!=w&&v!=w){ A[++tot]=(data){p,q,0,i}; } continue; }// if(u!=w){ p=go(u,dep[u]-dep[w]-1); B[++cnt]=(data){u,w,p,i}; s1[u]=(ll)s1[u]*a[i]%mod; s2[u]=(ll)s2[u]*a[i]%mod; s2[p]=(ll)s2[p]*ia[i]%mod; } if(v!=w){ q=go(v,dep[v]-dep[w]-1); B[++cnt]=(data){v,w,q,i}; s1[v]=(ll)s1[v]*a[i]%mod; s2[v]=(ll)s2[v]*a[i]%mod; s2[q]=(ll)s2[q]*ia[i]%mod; } if(u!=w&&v!=w){ A[++tot]=(data){p,q,0,i}; H.upd(p,q,ia[i]); } } dfs2(1,0); dfs3(1,0); for(int i=1;i<=tot;++i){ int u=A[i].u,v=A[i].v; ans[A[i].id]=(ll)ans[A[i].id]*H.ask(u,v)%mod; } for(int i=1;i<=cnt;++i){ int u=B[i].u,v=B[i].v,w=B[i].w,now=1; now=cal3(u,v); now=(ll)now*pw(cal2(u,w),mod-2)%mod; now=(ll)now*cal1(u,v)%mod; ans[B[i].id]=(ll)ans[B[i].id]*now%mod; int t = cal4(u,v); if(!a[B[i].id])t-=dep[u]-dep[v]; if(t)ans[B[i].id]=0; } int Ans=0; for(int i=1;i<=m;++i){ ans[i] = (ll)ans[i]*b[i]%mod; if(a[i])ans[i] = (ll)ans[i]*pw(a[i],mod-2)%mod; Ans=(Ans+ans[i])%mod; } cout<<Ans<<endl; return 0; }