dtoj#4317. 随机(random)
题目描述:
有一棵$N$个点的树和$M$个操作$x,y,S$,表示给$x$到$y$的链上的节点都加入一个数字字符串$S$。
所有操作都结束后需要对每个点进行一次询问。首先在该节点的所有字符串形成的`trie`上随机选择一个点(此时为步数$0$),这里定义`trie`的根节点为空串,然后每次随机一个与当前点相邻的点移动过去且步数$+1$,如果移动到了某个字符串对应的节点则结束。询问每个点对应`trie`的期望步数。
输出模$998244353$意义下的值。假如答案为$\frac{x}{y}(gcd(x,y)=1)$,那么需要输出$r(0≤r<998244353)$满足$x \equiv y×r(mod~998244353)$。
数据保证答案不会存在$y \equiv 0(mod~998244353)$的情况。
思路:
考虑对于一棵固定的 $trie$ 树,一个点走到终止点的期望会是
$$
f[x]=\frac{1}{d[x]}\times((\sum_{son}f[to])+f[fa])+1
$$
那么如果我们对于一颗 $trie$ 树 $dfs$ 时,对于孩子节点的情况已经求得了,我们把 $f[fax]$ 看作未知数, $f[x]$ 可以表示成一个关于 $f[fax]$ 的一次函数。
$$
f[x]=A_x\times f[fax]+B_x
$$
考虑如何确定每一个值的 $A_x$ 和 $B_x$ ,访问到 $x$ 对于每一个孩子的情况已经确定,即:
$$
f[to]=A_{to}\times f[x]+B_{to}
$$
那么:
$$
f[x]=\frac{1}{d[x]}\times((\sum_{son}A_{to}f[x]+B_{to})+f[fax])+1
$$
整理一下得到:
$$
f[x]=\frac{1}{d[x]-\sum_{son}A_{to}}f[fax]+\frac{(\sum_{son}B_{to})+d[x]}{d[x]-\sum_{son}A_{to}}
$$
容易知道最后对于一棵 $trie$ 树的答案是
$$
Ans=\frac{\sum_{i=1}^{cnt} f[x]}{cnt}
$$
( $cnt$ 表示 $trie$ 树的点数 )
所以要考虑维护整个子树的和
我们再用一样的方法表示出子树的和 $g[x]$ :
$$
g[x]=f[x]+\sum_{son}g[to]
$$
同理 $g[x]$ 可以表示成关于 $ f[fax] $ 的一次函数:
$$
g[to]=C_xf[fax]+D_x
$$
$$
g[x]=f[x]+\sum_{son}(C_{x}f[x]+D_{x})
$$
$$
g[x]=A_xf[fax]+Bx+\sum_{son}(C_{to}(A_xf[fax]+B_x)+D_{to})
$$
$$
g[x]=((1+\sum_{son}C_{to})A_{x})f[fax]+(\sum_{son}C_{to}+1)B_x+\sum_{son}D_{to}
$$
对于树上结点的删除与添加在 $trie$ 树上动态修改
以下代码:
#include<bits/stdc++.h> #define il inline #define pb push_back #define LL long long #define _(d) while(d(isdigit(ch=getchar()))) using namespace std; const int N=3e5+5,M=2e6+5,p=998244353; char s[M],c[M]; int n,head[N],ne[N<<1],to[N<<1],cnt,fa[N][21],d[N],res[N],be[N]; int sz[M],rt[N],A[M],B[M],C[M],D[M],ch[M][12],num[M],tag[M],m; struct node{int x,c;}; vector<node> v[N]; il int read(){ int x,f=1;char ch; _(!)ch=='-'?f=-1:f;x=ch^48; _()x=(x<<1)+(x<<3)+(ch^48); return f*x; } il int mu(int x,int y){ return x+y>=p?x+y-p:x+y; } il int ksm(LL a,int y){ LL b=1; while(y){ if(y&1)b=b*a%p; a=a*a%p;y>>=1; } return b; } il void ins(int x,int y){ ne[++cnt]=head[x]; head[x]=cnt;to[cnt]=y; } il void dfs1(int x){ for(int i=1;fa[x][i-1];i++)fa[x][i]=fa[fa[x][i-1]][i-1]; for(int i=head[x];i;i=ne[i]){ if(fa[x][0]==to[i])continue; fa[to[i]][0]=x; d[to[i]]=d[x]+1;dfs1(to[i]); } } il int Lca(int x,int y){ if(d[x]<d[y])swap(x,y); for(int i=20;i>=0;i--)if(d[fa[x][i]]>=d[y])x=fa[x][i]; if(x==y)return x; for(int i=20;i>=0;i--)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i]; return fa[x][0]; } il void update(int x,int f=0){ int Sa=0,Sb=0,Sc=0,Sd=0,d=0; A[x]=B[x]=C[x]=D[x]=0;d=f^1; if(!num[x]){sz[x]=0;return;}sz[x]=1; for(int i=0,to;i<=9;i++)if(sz[to=ch[x][i]]>0){ d++;sz[x]+=sz[to]; Sa=mu(Sa,A[to]);Sb=mu(Sb,B[to]); Sc=mu(Sc,C[to]);Sd=mu(Sd,D[to]); } if(tag[x]){D[x]=Sd;return;} A[x]=ksm(mu(d,p-Sa),p-2);B[x]=1ll*A[x]*mu(Sb,d)%p; C[x]=1ll*A[x]*mu(Sc,1)%p;D[x]=mu(1ll*B[x]*mu(Sc,1)%p,Sd); } il void add(int &x,int id,int l){ if(!x)x=++cnt;num[x]++; if(l>=be[id+1]){tag[x]++;update(x);return;} add(ch[x][s[l]-'0'],id,l+1);update(x); } il void del(int x,int id,int l){ num[x]--; if(l>=be[id+1]){tag[x]--;update(x);return;} del(ch[x][s[l]-'0'],id,l+1);update(x); } il void merge(int &x,int y){ if(!num[x]||!num[y]){x=(num[x]?x:y);return;} num[x]+=num[y];tag[x]+=tag[y]; for(int i=0;i<=9;i++)merge(ch[x][i],ch[y][i]); update(x); } il int query(int x){ update(x,1); return 1ll*D[x]*ksm(sz[x],p-2)%p; } il void dfs(int x,int fa){ int son=0; for(int i=head[x];i;i=ne[i]){ if(fa^to[i]){ dfs(to[i],x);son=to[i]; } } if(son)rt[x]=rt[son]; for(int i=head[x];i;i=ne[i]) if(fa^to[i]&&to[i]^son)merge(rt[x],rt[to[i]]); for(int i=0;i<v[x].size();i++){ node k=v[x][i]; if(k.c>0)add(rt[x],k.x,be[k.x]); else del(rt[x],k.x,be[k.x]); } res[x]=query(rt[x]); } int main() { n=read(); for(int i=1;i<n;i++){ int x=read(),y=read(); ins(x,y);ins(y,x); } d[1]=1;dfs1(1);m=read(); int now=1; for(int i=1;i<=m;i++){ int x=read(),y=read(); scanf(" %s",c+1); be[i]=now;int l=strlen(c+1); for(int i=1;i<=l;i++)s[now++]=c[i]; int lca=Lca(x,y); if(d[x]>d[y])swap(x,y); if(x==lca)v[fa[x][0]].pb((node){i,-1}),v[y].pb((node){i,1}); else v[x].pb((node){i,1}),v[y].pb((node){i,1}),v[lca].pb((node){i,-1}),v[fa[lca][0]].pb((node){i,-1}); } be[m+1]=now; cnt=0;dfs(1,0); for(int i=1;i<=n;i++)printf("%d\n",res[i]); return 0; }