【洛谷P1471】方差
题目大意:维护一个有 N 个元素的序列,支持以下操作:区间加,区间询问均值,区间询问方差。
题解:可知区间均值和区间和有关,即:维护区间和就等于维护了区间均值。区间方差表达式为 $$\frac{\Sigma_{i=1}n(a[i]-aver)2}{n}$$,化简之后可知还需维护区间的平方和。
这道题说明了,对于线段树来说,维护的东西并不一定直接是需要维护的东西,可以维护一些间接的信息,最后综合到一起计算得到需要维护的答案。
代码如下
#include <bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
int n,q;
double a[maxn];
struct node{int lc,rc;double tag,sum,sum2;};
struct segment_tree{
#define ls t[k].lc
#define rs t[k].rc
node t[maxn<<1];
int tot;
segment_tree():tot(1){memset(t,0,sizeof(t));}
inline void pushup(int k){
t[k].sum=t[ls].sum+t[rs].sum;
t[k].sum2=t[ls].sum2+t[rs].sum2;
}
inline void pushdown(int k,int l,int r){
int mid=l+r>>1;
t[ls].sum2+=(mid-l+1)*t[k].tag*t[k].tag+2*t[k].tag*t[ls].sum;
t[ls].sum+=(mid-l+1)*t[k].tag;
t[ls].tag+=t[k].tag;
t[rs].sum2+=(r-mid)*t[k].tag*t[k].tag+2*t[k].tag*t[rs].sum;
t[rs].sum+=(r-mid)*t[k].tag;
t[rs].tag+=t[k].tag;
t[k].tag=0;
}
void build(int k,int l,int r){
if(l==r){t[k].sum=a[l],t[k].sum2=a[l]*a[l];return;}
int mid=l+r>>1;
ls=++tot,build(ls,l,mid);
rs=++tot,build(rs,mid+1,r);
pushup(k);
}
void modify(int k,int l,int r,int x,int y,double val){
if(l==x&&r==y){
t[k].sum2+=2*val*t[k].sum+(r-l+1)*val*val;
t[k].sum+=(r-l+1)*val;
t[k].tag+=val;
return;
}
int mid=l+r>>1;
pushdown(k,l,r);
if(y<=mid)modify(ls,l,mid,x,y,val);
else if(x>mid)modify(rs,mid+1,r,x,y,val);
else modify(ls,l,mid,x,mid,val),modify(rs,mid+1,r,mid+1,y,val);
pushup(k);
}
double query1(int k,int l,int r,int x,int y){
if(l==x&&r==y)return t[k].sum;
int mid=l+r>>1;
pushdown(k,l,r);
if(y<=mid)return query1(ls,l,mid,x,y);
else if(x>mid)return query1(rs,mid+1,r,x,y);
else return query1(ls,l,mid,x,mid)+query1(rs,mid+1,r,mid+1,y);
}
double query2(int k,int l,int r,int x,int y){
if(l==x&&r==y)return t[k].sum2;
int mid=l+r>>1;
pushdown(k,l,r);
if(y<=mid)return query2(ls,l,mid,x,y);
else if(x>mid)return query2(rs,mid+1,r,x,y);
else return query2(ls,l,mid,x,mid)+query2(rs,mid+1,r,mid+1,y);
}
double mean(int l,int r){return this->query1(1,1,n,l,r)/(r-l+1);}
double var(int l,int r){
double tmp=this->mean(l,r);
double tmp2=this->query2(1,1,n,l,r);
return tmp2/(r-l+1)-tmp*tmp;
}
}sgt;
void read_and_parse(){
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++)scanf("%lf",&a[i]);
sgt.build(1,1,n);
}
void solve(){
int opt,x,y;
double k;
while(q--){
scanf("%d",&opt);
if(opt==1){
scanf("%d%d%lf",&x,&y,&k);
sgt.modify(1,1,n,x,y,k);
}else if(opt==2){
scanf("%d%d",&x,&y);
printf("%.4lf\n",sgt.mean(x,y));
}else if(opt==3){
scanf("%d%d",&x,&y);
printf("%.4lf\n",sgt.var(x,y));
}
}
}
int main(){
read_and_parse();
solve();
return 0;
}