Luogu4927 梦美与线段树 - 线段树 -
题目链接:https://www.luogu.com.cn/problem/P4927
题解:
《星之梦》真的不错,key顶尖短篇,推荐。
首先看一下期望是什么:
从期望的定义出发,比如先在根节点尝试 \(sum[1..n]\) 次,那么根据概率得有 \(sum[1..mid]\) 个情况跑到左子树中,\(sum[mid+1..n]\) 跑到右子树中,也就是说这一次对于期望的贡献就是 \(\frac{sum[1..mid]^2+sum[mid+1..n]^2}{sum[1..n]}\)(平方是因为尝试了这么多次,并且这一次的权值也是这个值。)
递归下去我们就发现最后的期望值就是 \(\frac{A}{sum[1..n]}\),其中 \(A\) 就是对于线段树上每一个点 \(i\),这个点代表的区间 \([l..r]\) 的权值和 \(A_i\),求这个和的平方再对所有点求和得到(\(A=\sum A_i^2\))
现在要带修改,考虑怎么区间修改的同时维护 \(A\)。
考虑对于线段树一个点 \(i\) ,如果这个点对应的区间 \([l..r]\) 中所有数都 \(+v\),那么 \(A\) 如何变化(考虑这个是因为我们要打懒标记,因此必须预先知道这个区间的变化)?
\(\Delta A_i=(sum[l..r]+len\times v)^2-sum[l..r]^2=2\times sum[l..r]\)\(\times v \times len +\ len^2\times v^2\)
因此只需要维护 \(\sum sum[l..r]\times len\) 和 \(\sum len^2\) 即可,其中后者是个常量,可以在建树的时候处理出来。
考虑前者该如何维护,\((sum[l..r]+v\times len)\times len = sum[l..r]\times len+len^2\times v\),因此只需要借助 \(\sum len^2\) 即可(注意这里只是对 \(i\) 结点来说的,但是我们需要的是这个点及子树的和,因此我们需要 \(\sum len^2\),也就是说我要对子树中所有点都加一次作为当前子树根节点的答案,因此要用到 \(\sum len^2\))
pushup 也很好写,按定义维护变量即可
线段树区间修改懒标记的本质就是能对一个大区间整体操作,而不需要借助小区间递推上来(如区间加和,可以 \(sum[l..r]+len\times v\) 而不必借助子树),在本题中也是一样,借助 \(\sum len^2\) 和 \(\sum sum[l..r]\times len\),我们就可以实现对区间的整体操作,其实这也是 pushdown 能进行的基础
代码:
// by SkyRainWind
#include <bits/stdc++.h>
#define mpr make_pair
#define debug() cerr<<"Yoshino\n"
#define pii pair<int,int>
#define pb push_back
using namespace std;
typedef long long ll;
typedef long long LL;
#define int __int128
const int inf = 1e9, INF = 0x3f3f3f3f, maxn = 2e5+5, mod = 998244353;
int n,m,a[maxn];
struct segm{
int l; // l = \sum len^2
int len; // 当前区间长度
int sum,val,lazy; // sum 区间和 , val \sum sum[l..r]*len 对子树的点求和
int res; // 当前点子树中的点对应的区间和的平方的和(期望的分子)
}se[maxn << 2];
// S^2 -> (S+len*v)^2-S^2 = 2*S*len*v + len^2*v^2,其中 S 为区间和
// 维护 \sum len^2 (常量) \sum S*len
int pw(int x,int y){
if(!y)return 1;if(y==1)return x;
int mid=pw(x,y>>1);
if(y&1)return 1ll*mid*mid%mod*x%mod;
return 1ll*mid*mid%mod;
}
int read(){
char ch=getchar();int h=0,t=1;
while((ch>'9'||ch<'0')&&ch!='-') ch=getchar();
if(ch=='-') t=-1,ch=getchar();
while(ch>='0'&&ch<='9') h=h*10+ch-'0',ch=getchar();
return h*t;
}
void print(int x){if(x>9)print(x/10);putchar(x%10+'0');}
int sq(int x){return 1ll*x*x;}
void pushup(int num,int l,int r){
se[num].l = se[num<<1].l + se[num<<1|1].l + sq(r-l+1);
se[num].sum = se[num << 1].sum + se[num << 1|1].sum;
se[num].val = se[num << 1].val + se[num << 1|1].val + (r-l+1) * se[num].sum;
se[num].res = se[num << 1].res + se[num << 1|1].res + sq(se[num].sum);
}
void build(int x,int y,int num){
se[num].len = y-x+1;
if(x == y){
se[num].l = 1;
se[num].sum = se[num].val = a[x];
se[num].res = sq(a[x]);
return ;
}
int mid = x+y>>1;
build(x,mid,num <<1);build(mid+1,y,num<<1|1);
pushup(num,x,y);
}
void add(int num,int v){ // 对区间能进行整体操作
se[num].lazy += v;
se[num].res += 2*se[num].val*v + se[num].l*v*v;
se[num].sum += se[num].len * v;
se[num].val += se[num].l * v;
}
void pushdown(int num,int l,int r){
if(!se[num].lazy)return ;
add(num<<1,se[num].lazy);
add(num<<1|1,se[num].lazy);
se[num].lazy = 0;
}
void upd(int x,int y,int v,int l,int r,int num){
if(x<=l && r<=y){
add(num, v);
return ;
}
pushdown(num,l,r);
int mid=l+r>>1;
if(y<=mid)upd(x,y,v,l,mid,num<<1);
else if(x>mid)upd(x,y,v,mid+1,r,num<<1|1);
else upd(x,y,v,l,mid,num<<1), upd(x,y,v,mid+1,r,num<<1|1);
pushup(num,l,r);
}
signed main(){
n=read(),m=read();
for(int i=1;i<=n;i++)a[i]=read();
build(1,n,1);
while(m --){
int op=read(),l,r,v;
if(op == 2){
int fz = se[1].res / __gcd(se[1].res,se[1].sum);
int fm = se[1].sum / __gcd(se[1].res,se[1].sum);
print(fz%mod*pw(fm,mod-2)%mod);puts("");
}else{
l=read(),r=read(),v=read();
upd(l,r,v,1,n,1);
}
}
return 0;
}