[loj3312]传统艺能
定义【被修改】表示在$[l,r]\subseteq [q_{l},q_{r}]$且$[l_{fa},r_{fa}]\nsubseteq [q_{l},q_{r}]$,【被经过】表示$[l,r]\nsubseteq [q_{l},q_{r}]$且$[l,r]\cap [q_{l},q_{r}]\neq\empty$
将区间贡献分开来统计,状态可以用$自己是否为祖先中是否存在(自己是否为1,祖先中是否存在1)$来描述,记$P(k,p_{1},p_{2})$表示点$k$状态为$(p_{1},p_{2})$的概率,考虑区间操作对状态的影响($p$表示对应区间的概率):
1.被修改,转移为$P'(k,1,0)+=p$
2.被经过,转移为$P'(k,0,0)+=p$
3.祖先被修改,转移为$P'(k,p_{1},1)+=p\cdot(P(k,p_{1},0)+P(k,p_{1},1))$
4.父亲被经过且自己未被经过,转移为$P'(k,0,0)+=p\cdot P(k,0,0)$,$P'(k,1,0)+=p\cdot (1-P(k,0,0))$
5.不属于以上任何一种,转移为$P'(k,p_{1},p_{2})+=p\cdot P(k,p_{1},p_{2})$
特别的,对于$[1,n]$,其有贡献当且仅当最后一次覆盖的区间为$[1,n]$,概率为$\frac{2}{n(n+1)}$
最终答案即为$P(k,1,0)+P(k,1,1)$即为这个区间的贡献,转移用矩阵乘法来维护即可
View Code
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 400005 4 #define mod 998244353 5 struct ji{ 6 int l,r; 7 }seg[N]; 8 struct mat{ 9 int a[4][4]; 10 }e,o; 11 int V,r,n,m,all,ans,fa[N],ls[N],rs[N],mid[N]; 12 int ksm(int n,int m){ 13 if (!m)return 1; 14 int s=ksm(n,m>>1); 15 s=1LL*s*s%mod; 16 if (m&1)s=1LL*s*n%mod; 17 return s; 18 } 19 mat mul(mat x,mat y){ 20 mat o; 21 memset(o.a,0,sizeof(o.a)); 22 for(int i=0;i<4;i++) 23 for(int j=0;j<4;j++) 24 for(int k=0;k<4;k++) 25 o.a[i][j]=(o.a[i][j]+1LL*x.a[i][k]*y.a[k][j])%mod; 26 return o; 27 } 28 mat ksm(mat n,int m){ 29 if (!m)return e; 30 mat o=ksm(n,m>>1); 31 o=mul(o,o); 32 if (m&1)o=mul(o,n); 33 return o; 34 } 35 void build(int &k,int l,int r){ 36 if (!k)seg[k=++V]=ji{l,r}; 37 if (l==r)return; 38 scanf("%d",&mid[k]); 39 build(ls[k],l,mid[k]); 40 build(rs[k],mid[k]+1,r); 41 fa[ls[k]]=fa[rs[k]]=k; 42 } 43 int calc1(int k){ 44 return (1LL*(seg[k].l-seg[fa[k]].l)*(n-seg[k].r+1)+1LL*seg[k].l*(seg[fa[k]].r-seg[k].r))%mod*all%mod; 45 } 46 int calc2(int k){ 47 return (1LL*n*(seg[k].r-seg[k].l)-1LL*(seg[k].r-seg[k].l)*(seg[k].r-seg[k].l-1)/2)%mod*all%mod; 48 } 49 int calc3(int k){ 50 return 1LL*seg[fa[k]].l*(n-seg[fa[k]].r+1)%mod*all%mod; 51 } 52 int calc4(int k){ 53 return (1LL*(seg[k].l+seg[fa[k]].l-1)*(seg[k].l-seg[fa[k]].l)/2+1LL*(n+n-seg[k].r-seg[fa[k]].r+1)*(seg[fa[k]].r-seg[k].r)/2)%mod*all%mod; 54 } 55 int main(){ 56 scanf("%d%d",&n,&m); 57 build(r,1,n); 58 all=ans=ksm((n+1LL)*n/2%mod,mod-2); 59 e.a[0][0]=e.a[1][1]=e.a[2][2]=e.a[3][3]=1; 60 for(int i=2;i<=V;i++){ 61 int p1,p2,p3,p4,p5; 62 p1=calc1(i); 63 p2=calc2(i); 64 p3=calc3(i); 65 p4=calc4(i); 66 p5=mod+1-((p1+p2)%mod+(p3+p4)%mod)%mod; 67 o.a[0][0]=((p2+p4)%mod+p5)%mod; 68 o.a[1][0]=o.a[2][0]=o.a[3][0]=p2; 69 o.a[0][1]=p3; 70 o.a[1][1]=(p3+p5)%mod; 71 o.a[2][1]=o.a[3][1]=0; 72 o.a[0][2]=p1; 73 o.a[1][2]=o.a[3][2]=(p1+p4)%mod; 74 o.a[2][2]=((p1+p4)%mod+p5)%mod; 75 o.a[0][3]=o.a[1][3]=0; 76 o.a[2][3]=p3; 77 o.a[3][3]=(p3+p5)%mod; 78 o=ksm(o,m); 79 ans=(ans+(o.a[0][2]+o.a[0][3])%mod)%mod; 80 } 81 printf("%d",ans); 82 }