[loj3312]传统艺能

定义【被修改】表示在$[l,r]\subseteq [q_{l},q_{r}]$且$[l_{fa},r_{fa}]\nsubseteq  [q_{l},q_{r}]$,【被经过】表示$[l,r]\nsubseteq [q_{l},q_{r}]$且$[l,r]\cap [q_{l},q_{r}]\neq\empty$
将区间贡献分开来统计,状态可以用$自己是否为祖先中是否存在(自己是否为1,祖先中是否存在1)$来描述,记$P(k,p_{1},p_{2})$表示点$k$状态为$(p_{1},p_{2})$的概率,考虑区间操作对状态的影响($p$表示对应区间的概率):
1.被修改,转移为$P'(k,1,0)+=p$
2.被经过,转移为$P'(k,0,0)+=p$
3.祖先被修改,转移为$P'(k,p_{1},1)+=p\cdot(P(k,p_{1},0)+P(k,p_{1},1))$
4.父亲被经过且自己未被经过,转移为$P'(k,0,0)+=p\cdot P(k,0,0)$,$P'(k,1,0)+=p\cdot (1-P(k,0,0))$
5.不属于以上任何一种,转移为$P'(k,p_{1},p_{2})+=p\cdot P(k,p_{1},p_{2})$
特别的,对于$[1,n]$,其有贡献当且仅当最后一次覆盖的区间为$[1,n]$,概率为$\frac{2}{n(n+1)}$
最终答案即为$P(k,1,0)+P(k,1,1)$即为这个区间的贡献,转移用矩阵乘法来维护即可
 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 #define N 400005
 4 #define mod 998244353
 5 struct ji{
 6     int l,r;
 7 }seg[N];
 8 struct mat{
 9     int a[4][4];
10 }e,o;
11 int V,r,n,m,all,ans,fa[N],ls[N],rs[N],mid[N];
12 int ksm(int n,int m){
13     if (!m)return 1;
14     int s=ksm(n,m>>1);
15     s=1LL*s*s%mod;
16     if (m&1)s=1LL*s*n%mod;
17     return s;
18 }
19 mat mul(mat x,mat y){
20     mat o;
21     memset(o.a,0,sizeof(o.a));
22     for(int i=0;i<4;i++)
23         for(int j=0;j<4;j++)
24             for(int k=0;k<4;k++)
25                 o.a[i][j]=(o.a[i][j]+1LL*x.a[i][k]*y.a[k][j])%mod;
26     return o;
27 }
28 mat ksm(mat n,int m){
29     if (!m)return e;
30     mat o=ksm(n,m>>1);
31     o=mul(o,o);
32     if (m&1)o=mul(o,n);
33     return o;
34 }
35 void build(int &k,int l,int r){
36     if (!k)seg[k=++V]=ji{l,r};
37     if (l==r)return;
38     scanf("%d",&mid[k]);
39     build(ls[k],l,mid[k]);
40     build(rs[k],mid[k]+1,r);
41     fa[ls[k]]=fa[rs[k]]=k;
42 }
43 int calc1(int k){
44     return (1LL*(seg[k].l-seg[fa[k]].l)*(n-seg[k].r+1)+1LL*seg[k].l*(seg[fa[k]].r-seg[k].r))%mod*all%mod;
45 }
46 int calc2(int k){
47     return (1LL*n*(seg[k].r-seg[k].l)-1LL*(seg[k].r-seg[k].l)*(seg[k].r-seg[k].l-1)/2)%mod*all%mod;
48 }
49 int calc3(int k){
50     return 1LL*seg[fa[k]].l*(n-seg[fa[k]].r+1)%mod*all%mod;
51 }
52 int calc4(int k){
53     return (1LL*(seg[k].l+seg[fa[k]].l-1)*(seg[k].l-seg[fa[k]].l)/2+1LL*(n+n-seg[k].r-seg[fa[k]].r+1)*(seg[fa[k]].r-seg[k].r)/2)%mod*all%mod;
54 }
55 int main(){
56     scanf("%d%d",&n,&m);
57     build(r,1,n);
58     all=ans=ksm((n+1LL)*n/2%mod,mod-2);
59     e.a[0][0]=e.a[1][1]=e.a[2][2]=e.a[3][3]=1;
60     for(int i=2;i<=V;i++){
61         int p1,p2,p3,p4,p5;
62         p1=calc1(i);
63         p2=calc2(i);
64         p3=calc3(i);
65         p4=calc4(i);
66         p5=mod+1-((p1+p2)%mod+(p3+p4)%mod)%mod;
67         o.a[0][0]=((p2+p4)%mod+p5)%mod;
68         o.a[1][0]=o.a[2][0]=o.a[3][0]=p2;
69         o.a[0][1]=p3;
70         o.a[1][1]=(p3+p5)%mod;
71         o.a[2][1]=o.a[3][1]=0;
72         o.a[0][2]=p1;
73         o.a[1][2]=o.a[3][2]=(p1+p4)%mod;
74         o.a[2][2]=((p1+p4)%mod+p5)%mod;
75         o.a[0][3]=o.a[1][3]=0;
76         o.a[2][3]=p3;
77         o.a[3][3]=(p3+p5)%mod;
78         o=ksm(o,m);
79         ans=(ans+(o.a[0][2]+o.a[0][3])%mod)%mod;
80     }
81     printf("%d",ans);
82 }
View Code

 

posted @ 2020-08-10 15:21  PYWBKTDA  阅读(187)  评论(0编辑  收藏  举报