21.7.15

tag:线段树,概率期望,矩阵乘法


\[E'(x)=pE(ax+b)+(1-p)E(x) \]

\[E'(x^2)=pE((ax+b)^2)+(1-p)E(x^2) \]

\[E'(x)=(1-p+pa)E(x)+pb \]

\[E'(x^2)=(1-p+pa^2)E(x^2)+2pabE(x)+pb^2 \]

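把状态写成行向量 $\begin{bmatrix}E(x^2)&E(x)&1\end{bmatrix}$，则每个位置上的操作就是右乘一个 $3\times3$ 的转移矩阵：

\[\begin{bmatrix}E'(x^2)&E'(x)&1\end{bmatrix}=\begin{bmatrix}E(x^2)&E(x)&1\end{bmatrix}\begin{bmatrix}1-p+pa^2&0&0\\2pab&1-p+pa&0\\pb^2&pb&1\end{bmatrix} \]

然后线段树维护区间矩阵乘法（按位置从左到右依次相乘）就行了。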


#include<bits/stdc++.h>
using namespace std;
 
// Reads one (optionally negative) decimal integer from stdin into n.
// Non-digit characters before the number are skipped; a '-' makes the value
// negative only when it immediately precedes the first digit.
// If EOF is reached before any digit, n is set to 0 and the function returns
// (the previous version spun forever on EOF).
// `ch` is an int so that EOF is distinguishable and isdigit() never receives
// a negative non-EOF value (which is undefined behaviour for a signed char).
template<typename T>
inline void Read(T &n){
    int ch = getchar();
    bool negative = false;
    while(ch != EOF && !isdigit(ch)){
        negative = (ch == '-');
        ch = getchar();
    }
    n = 0;
    if(ch == EOF) return;
    for(; isdigit(ch); ch = getchar())
        n = n*10 + (ch - '0');
    if(negative) n = -n;
}
 
// Capacity of the per-position arrays (positions are 1..n, so n must be < MAXN)
// and the prime modulus all arithmetic is done under.
constexpr int MAXN = 50005;
constexpr int MOD = 998244353;
 
// Binary exponentiation: returns base^k mod MOD.
// With the default exponent MOD-2 this is the modular inverse of base
// (Fermat's little theorem, MOD being prime).
inline int ksm(int base, int k=MOD-2){
    long long acc = 1;
    long long cur = base;
    for(; k; k >>= 1, cur = cur*cur%MOD)
        if(k & 1)
            acc = acc*cur%MOD;
    return static_cast<int>(acc);
}
 
// Modular subtraction (a - b) mod MOD; correct when both operands are already in [0, MOD).
inline int dec(int a, int b){
    const int diff = a - b;
    return diff < 0 ? diff + MOD : diff;
}
 
// Modular addition (a + b) mod MOD; correct when both operands are already in [0, MOD).
inline int inc(int a, int b){
    const int sum = a + b;
    return sum >= MOD ? sum - MOD : sum;
}
 
// In-place modular subtraction: a <- (a - b) mod MOD.
// `a` must be taken by reference; the previous by-value parameter made this a no-op.
inline void ddec(int &a, int b){a = dec(a,b);}
// In-place accumulate: a <- (a + b) mod MOD. b is a long long so callers can pass
// an unreduced product of two residues.
inline void upd(int &a, long long b){
    const long long total = a + b;
    a = static_cast<int>(total % MOD);
}
 
// Per-position operation parameters, 1-indexed (main fills i = 1..n):
// with probability p[i] (given mod MOD) position i maps x -> a[i]*x + b[i],
// otherwise x is left unchanged. n = number of positions, m = number of operations.
int p[MAXN], a[MAXN], b[MAXN], n, m;
 
// Returns x^2 mod MOD (computed in 64-bit to avoid overflow).
inline int sqr(int x){
    const long long squared = 1ll*x*x;
    return static_cast<int>(squared % MOD);
}
 
// Scratch results mod MOD: ex = E(x), ex2 = E(x^2). Written by solve() and by
// the query branch of main().
int ex, ex2;
inline void solve(int l, int r, int x){
    ex = x; ex2 = 1ll*x*x%MOD;
    for(int i=l; i<=r; i++)
        ex2 = (1ll*(1+1ll*p[i]*sqr(a[i])-p[i]+MOD)%MOD*ex2+2ll*p[i]*a[i]%MOD*b[i]%MOD*ex+1ll*p[i]*sqr(b[i]))%MOD,
        ex = (1ll*(1+1ll*p[i]*a[i]-p[i]+MOD)%MOD*ex+1ll*p[i]*b[i])%MOD;
}
 
// Dense matrix of logical size n x m over Z_MOD. Storage is always 3x3 and
// zero-initialised; only the top-left n x m part is meaningful.
struct Matrix{
    int a[3][3], n, m;
    Matrix(){n=3,m=3;memset(a,0,sizeof a);}
    Matrix(int N, int M){n=N,m=M;memset(a,0,sizeof a);}
    // Returns (*this) * k with entries reduced mod MOD.
    // Assumes this->m == k.n (not checked). Declared const so it can be used on
    // const operands; the inner index is named `q` to avoid shadowing the global
    // probability array `p`.
    inline Matrix operator *(const Matrix &k) const {
        Matrix res(n,k.m);
        for(int i=0; i<n; i++)
            for(int j=0; j<k.m; j++)
                for(int q=0; q<m; q++)
                    upd(res.a[i][j],1ll*a[i][q]*k.a[q][j]);
        return res;
    }
}
// t: segment-tree node matrices (node x has children 2x and 2x+1); each node
//    holds the left-to-right product of its positions' transition matrices.
// f: 1x3 row vector [E(x^2), E(x), 1] used as the query accumulator.
t[MAXN<<2], f(1,3);
 
// Children of segment-tree node x in the implicit 1-rooted heap layout.
inline int lc(int x){return 2*x;}
inline int rc(int x){return 2*x + 1;}
 
void Update(int x, int head, int tail, int pos){
    if(head==tail){
        t[x].a[0][0] = (1+1ll*p[pos]*sqr(a[pos])-p[pos]+MOD)%MOD;
        t[x].a[1][0] = 2ll*p[pos]*a[pos]%MOD*b[pos]%MOD;
        t[x].a[1][1] = (1+1ll*p[pos]*a[pos]-p[pos]+MOD)%MOD;
        t[x].a[2][0] = 1ll*p[pos]*sqr(b[pos])%MOD;
        t[x].a[2][1] = 1ll*p[pos]*b[pos]%MOD;
        t[x].a[2][2] = 1;
        return;
    }
    int mid = head+tail >> 1;
    if(pos<=mid) Update(lc(x),head,mid,pos);
    if(mid<pos) Update(rc(x),mid+1,tail,pos);
    t[x] = t[lc(x)]*t[rc(x)];
}
 
// Range query: right-multiplies the global row vector f by the matrices of the
// nodes that exactly cover [l, r], visiting them left to right so the product
// follows position order.
void Query(int x, int head, int tail, int l, int r){
    if(l <= head && tail <= r){
        f = f * t[x];
        return;
    }
    const int mid = (head+tail) >> 1;
    if(l <= mid)
        Query(lc(x), head, mid, l, r);
    if(r > mid)
        Query(rc(x), mid+1, tail, l, r);
}
 
// Input: n m, then n triples "p a b", then m operations:
//   0 x p a b   -> replace the parameters of position x
//   1 l r x     -> print E(x) after applying positions l..r to the start value x
//   2 l r x     -> print Var = E(x^2) - E(x)^2 for the same process
// Any other opcode runs the query but prints nothing.
int main(){
    Read(n);
    Read(m);
    for(int i=1; i<=n; ++i){
        Read(p[i]);
        Read(a[i]);
        Read(b[i]);
        Update(1,1,n,i);
    }
    while(m--){
        int opt;
        Read(opt);
        if(opt == 0){
            int pos;
            Read(pos);
            Read(p[pos]);
            Read(a[pos]);
            Read(b[pos]);
            Update(1,1,n,pos);
            continue;
        }
        int l, r, x;
        Read(l);
        Read(r);
        Read(x);
        // Start state as row vector [x^2, x, 1].
        f.a[0][0] = sqr(x);
        f.a[0][1] = x;
        f.a[0][2] = 1;
        Query(1,1,n,l,r);
        ex2 = f.a[0][0];
        ex = f.a[0][1];
        if(opt == 1)
            printf("%d\n",ex);
        else if(opt == 2)
            printf("%lld\n",(ex2-1ll*ex*ex%MOD+MOD)%MOD);
    }
    return 0;
}
 
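/*
Sample input: n=5 positions, m=1 operation; each position line is "p a b"
(499122177 is 1/2 mod 998244353); the last line "2 1 5 1" asks for the
variance over positions 1..5 starting from x = 1.
5 1
499122177 1 1
499122177 2 0
499122177 1 0
499122177 1 1
499122177 2 0
2 1 5 1
*/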
posted @ 2021-07-15 19:48  oisdoaiu  阅读(68)  评论(0)  编辑  收藏  举报