21.7.15
tag:线段树,概率期望,矩阵乘法
\[E'(x)=pE(ax+b)+(1-p)E(x)
\]
\[E'(x^2)=pE((ax+b)^2)+(1-p)E(x^2)
\]
\[E'(x)=(1-p+pa)E(x)+pb
\]
\[E'(x^2)=(1-p+pa^2)E(x^2)+2pabE(x)+pb^2
\]
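然后把每次操作写成行向量右乘一个转移矩阵：
\[\begin{pmatrix}E'(x^2)&E'(x)&1\end{pmatrix}=\begin{pmatrix}E(x^2)&E(x)&1\end{pmatrix}\begin{pmatrix}1-p+pa^2&0&0\\2pab&1-p+pa&0\\pb^2&pb&1\end{pmatrix}
\]
用线段树维护区间矩阵乘积（按下标从左到右相乘）就行了。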
#include<bits/stdc++.h>
using namespace std;
// Reads the next decimal integer from stdin into `n`.
//
// Every character before the first digit is skipped; if a '-' was seen while
// skipping, the parsed value is negated. Digits are accumulated in base 10
// with no overflow check, and the first non-digit after the number is
// consumed as well.
//
// If the input ends before any digit is found, `n` is set to 0 and the
// function returns (instead of looping forever on EOF).
template <typename T>
inline void Read(T &n) {
    // getchar() returns int: keeping it as int lets EOF be distinguished from
    // real bytes and keeps the argument of isdigit() inside its valid domain.
    int ch = getchar();
    bool negative = false;
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') negative = true;
        ch = getchar();
    }
    if (ch == EOF) {
        n = 0;
        return;
    }
    n = 0;
    for (; isdigit(ch); ch = getchar())
        n = n * 10 + (ch - '0');
    if (negative) n = -n;
}
// Compile-time constants shared by the whole solution.
constexpr int MAXN = 50005;     // capacity of the per-position arrays
constexpr int MOD = 998244353;  // modulus used by every arithmetic helper
// Binary exponentiation: returns base^k reduced modulo MOD.
// The default exponent MOD-2 gives the modular inverse of `base` by Fermat's
// little theorem (valid when MOD is prime and base is not a multiple of MOD).
inline int ksm(int base, int k = MOD - 2) {
    int result = 1;
    for (; k != 0; k >>= 1) {
        // Multiply in the current power whenever the lowest exponent bit is set.
        if (k & 1) {
            result = 1ll * result * base % MOD;
        }
        // Square the running power for the next bit.
        base = 1ll * base * base % MOD;
    }
    return result;
}
// Returns a - b, adding MOD once if the difference is negative.
inline int dec(int a, int b) {
    const int diff = a - b;
    return diff < 0 ? diff + MOD : diff;
}
// Returns a + b, subtracting MOD once if the sum reaches MOD.
inline int inc(int a, int b) {
    const int sum = a + b;
    return sum >= MOD ? sum - MOD : sum;
}
// In-place modular subtraction: replaces `a` with dec(a, b).
// `a` is taken by reference so the caller's variable is actually updated
// (taking it by value made the function a no-op).
inline void ddec(int &a, int b) { a = dec(a, b); }
// In-place accumulate: replaces `a` with (a + b) % MOD, computed in 64 bits.
inline void upd(int &a, long long b) {
    const long long total = a + b;
    a = total % MOD;
}
// Parameters of the operation stored at each position (main fills 1..n):
// with probability p the value x becomes a*x + b, otherwise it is unchanged.
int p[MAXN];  // probability, read as an integer from the input
int a[MAXN];  // multiplier
int b[MAXN];  // addend
int n;        // number of positions
int m;        // number of commands still to process
// Returns x*x % MOD, with the product formed in 64 bits.
inline int sqr(int x) {
    const long long wide = x;
    return wide * wide % MOD;
}
// Results of the most recent evaluation.
int ex;   // E(x)
int ex2;  // E(x^2)
// Direct O(r-l+1) evaluation: starts from the fixed value x and applies the
// operations at positions l..r in order, leaving E(x) in `ex` and E(x^2) in
// `ex2` (both modulo MOD).
inline void solve(int l, int r, int x) {
    ex = x;
    ex2 = 1ll * x * x % MOD;
    for (int i = l; i <= r; ++i) {
        // Coefficients of the second-moment recurrence.
        const long long keep2 = (1 + 1ll * p[i] * sqr(a[i]) - p[i] + MOD) % MOD;  // 1 - p + p*a^2
        const long long cross = 2ll * p[i] * a[i] % MOD * b[i] % MOD;             // 2*p*a*b
        const long long shift2 = 1ll * p[i] * sqr(b[i]);                          // p*b^2
        // Coefficients of the first-moment recurrence.
        const long long keep1 = (1 + 1ll * p[i] * a[i] - p[i] + MOD) % MOD;       // 1 - p + p*a
        const long long shift1 = 1ll * p[i] * b[i];                               // p*b

        // Both new moments are derived from the OLD ex, so compute before storing.
        const int nextEx2 = (keep2 * ex2 + cross * ex + shift2) % MOD;
        const int nextEx = (keep1 * ex + shift1) % MOD;
        ex2 = nextEx2;
        ex = nextEx;
    }
}
// Small dense matrix (at most 3x3) whose entries are residues modulo MOD.
struct Matrix {
    int a[3][3];  // entries; operator* only reads the top-left n x m part
    int n, m;     // logical number of rows / columns

    // 3x3 zero matrix.
    Matrix() : a{}, n(3), m(3) {}

    // N x M zero matrix (N, M must not exceed 3).
    Matrix(int N, int M) : a{}, n(N), m(M) {}

    // Returns the product (*this) * k as an n x k.m matrix, entries mod MOD.
    // The inner dimension is taken from this->m; it is the caller's job to
    // make sure it matches k.n. Const-qualified so that it can be applied to
    // const operands as well.
    Matrix operator*(const Matrix &k) const {
        Matrix res(n, k.m);
        for (int i = 0; i < n; i++)
            for (int j = 0; j < k.m; j++)
                for (int mid = 0; mid < m; mid++)
                    upd(res.a[i][j], 1ll * a[i][mid] * k.a[mid][j]);
        return res;
    }
} t[MAXN << 2],  // segment-tree nodes: t[x] is the left-to-right product over node x's range
  f(1, 3);       // 1x3 row vector that Query multiplies the covering nodes into
// Index of the left child of segment-tree node x.
inline int lc(int x) {
    return 2 * x;
}
// Index of the right child of segment-tree node x.
inline int rc(int x) {
    return 2 * x + 1;
}
void Update(int x, int head, int tail, int pos){
if(head==tail){
t[x].a[0][0] = (1+1ll*p[pos]*sqr(a[pos])-p[pos]+MOD)%MOD;
t[x].a[1][0] = 2ll*p[pos]*a[pos]%MOD*b[pos]%MOD;
t[x].a[1][1] = (1+1ll*p[pos]*a[pos]-p[pos]+MOD)%MOD;
t[x].a[2][0] = 1ll*p[pos]*sqr(b[pos])%MOD;
t[x].a[2][1] = 1ll*p[pos]*b[pos]%MOD;
t[x].a[2][2] = 1;
return;
}
int mid = head+tail >> 1;
if(pos<=mid) Update(lc(x),head,mid,pos);
if(mid<pos) Update(rc(x),mid+1,tail,pos);
t[x] = t[lc(x)]*t[rc(x)];
}
// Multiplies the global row vector `f` by the matrices of the nodes that
// cover [l, r], visiting them from left to right.
//   x - current node, covering positions [head, tail]
void Query(int x, int head, int tail, int l, int r) {
    if (l <= head && tail <= r) {
        // Node lies completely inside the query range: apply it as a whole.
        f = f * t[x];
        return;
    }
    const int mid = (head + tail) >> 1;
    if (l <= mid) {
        Query(lc(x), head, mid, l, r);
    }
    if (r > mid) {
        Query(rc(x), mid + 1, tail, l, r);
    }
}
// Reads n and m, then the n (p, a, b) triples, then processes m commands:
//   0 x p a b : replace the operation at position x
//   1 l r x   : print E(x) after applying operations l..r to the value x
//   2 l r x   : print E(x^2) - E(x)^2 for the same process
// All results are printed modulo MOD.
int main() {
    Read(n);
    Read(m);
    for (int i = 1; i <= n; ++i) {
        Read(p[i]);
        Read(a[i]);
        Read(b[i]);
        Update(1, 1, n, i);
    }

    while (m--) {
        int opt;
        Read(opt);

        if (opt == 0) {
            int pos;
            Read(pos);
            Read(p[pos]);
            Read(a[pos]);
            Read(b[pos]);
            Update(1, 1, n, pos);
            continue;
        }

        int l, r, x;
        Read(l);
        Read(r);
        Read(x);

        // Start from the row vector (x^2, x, 1) and push it through [l, r].
        f.a[0][0] = sqr(x);
        f.a[0][1] = x;
        f.a[0][2] = 1;
        Query(1, 1, n, l, r);
        ex2 = f.a[0][0];
        ex = f.a[0][1];

        switch (opt) {
            case 1:
                printf("%d\n", ex);
                break;
            case 2:
                printf("%lld\n", (ex2 - 1ll * ex * ex % MOD + MOD) % MOD);
                break;
            default:
                break;
        }
    }
    return 0;
}
/*
5 1
499122177 1 1
499122177 2 0
499122177 1 0
499122177 1 1
499122177 2 0
2 1 5 1
*/