Luogu P1471 方差
题目
维护一个包含 \(n\) 个实数的数列,并有 \(m\) 个操作
操作为以下三种之一:
1 x y k
表示将第 \(x\) 到第 \(y\) 项每项加上 \(k\) ,\(k\) 为一实数2 x y
表示求出第 \(x\) 项到第 \(y\) 项这一子数列的平均数3 x y
表示求出第 \(x\) 项到第 \(y\) 项这一子数列的方差
分析
对于操作1和2,只需要用线段树维护数列的和即可
对于操作3,我们可以对方差公式进行变形:
\[\begin{aligned}
s^2&=\frac{1}{n}\sum_{i=1}^n(A_i-\overline A)^2\\
&=\frac{1}{n}(\sum_{i=1}^nA_i^2-2\overline A\sum_{i=1}^n A_i+n\overline A^2)\\
&=\frac{1}{n}\sum_{i=1}^n A_i^2-\overline A^2
\end{aligned}
\]
这样问题就转花了如何维护一个数列的平方和,只要在 update
函数上稍加修改即可:
void update(int k, int l, int r, double v)
{
sum2[k] += (r - l + 1) * v * v + 2 * v * sum1[k];
sum1[k] += (r - l + 1) * v;
add[k] += v;
}
代码
#include<bits/stdc++.h>
#define ls(k) k << 1
#define rs(k) k << 1 | 1
using namespace std;
const int MAX_N = 100000 + 5;
double a[MAX_N], sum1[MAX_N * 4], sum2[MAX_N * 4], add[MAX_N * 4];
void build(int k, int l, int r)
{
if(l == r) {
sum1[k] = a[l];
sum2[k] = a[l] * a[l];
return;
}
int m = (l + r) >> 1;
build(ls(k), l, m);
build(rs(k), m + 1, r);
sum1[k] = sum1[ls(k)] + sum1[rs(k)];
sum2[k] = sum2[ls(k)] + sum2[rs(k)];
}
void update(int k, int l, int r, double v)
{
sum2[k] += (r - l + 1) * v * v + 2 * v * sum1[k];
sum1[k] += (r - l + 1) * v;
add[k] += v;
}
void push_down(int k, int l, int r)
{
if(add[k] == 0)
return;
int m = (l + r) >> 1;
update(ls(k), l, m, add[k]);
update(rs(k), m + 1, r, add[k]);
add[k] = 0;
}
void modify(int k, int l, int r, int x, int y, double v)
{
if(l >= x && r <= y) {
update(k, l, r, v);
return;
}
push_down(k, l, r);
int m = (l + r) >> 1;
if(x <= m)
modify(ls(k), l, m, x, y, v);
if(m + 1 <= y)
modify(rs(k), m + 1, r, x, y, v);
sum1[k] = sum1[ls(k)] + sum1[rs(k)];
sum2[k] = sum2[ls(k)] + sum2[rs(k)];
}
double query1(int k, int l, int r, int x, int y)
{
if(l >= x && r <= y)
return sum1[k];
push_down(k, l, r);
int m = (l + r) >> 1;
double res = 0;
if(x <= m)
res += query1(ls(k), l, m, x, y);
if(m + 1 <= y)
res += query1(rs(k), m + 1, r, x, y);
return res;
}
double query2(int k, int l, int r, int x, int y)
{
if(l >= x && r <= y)
return sum2[k];
push_down(k, l, r);
int m = (l + r) >> 1;
double res = 0;
if(x <= m)
res += query2(ls(k), l, m, x, y);
if(m + 1 <= y)
res += query2(rs(k), m + 1, r, x, y);
return res;
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++)
scanf("%lf", &a[i]);
build(1, 1, n);
while(m--) {
int opt, x, y;
double k;
scanf("%d%d%d", &opt, &x, &y);
if(opt == 1) {
scanf("%lf", &k);
modify(1, 1, n, x, y, k);
} else if(opt == 2) {
printf("%.4lf\n", query1(1, 1, n, x, y) / (y - x + 1));
} else {
double t = query1(1, 1, n, x, y) / (y - x + 1);
printf("%.4lf\n", query2(1, 1, n, x, y) / (y - x + 1) - t * t);
}
}
return 0;
}