P3707 题解
前言
思路和码量都有难度,思路难点是推式子,代码难度是线段树操作很奇怪,导致很难打。
思路
转化式子
首先将 \(a\) 爆拆,目标就是把平均值搞掉。(有点恐怖,建议仔细看;看不懂没关系,记住最后化简的式子就行)
\(\begin{aligned} a & = \dfrac{\sum_{i = L}^R (x_i - \overline{x})\cdot(y_i - \overline{y})}{\sum_{i = L}^R (x_i - \overline{x})^2} \\ & = \dfrac{\sum_{i = L}^R (x_i\cdot y_i - x_i\cdot \overline{y} - y_i\cdot \overline{x} + \overline{x} \cdot \overline{y})}{\sum_{i = L}^R x_i^2 - 2x_i\overline{x} + \overline{x} ^ 2}\\ & = \dfrac{(\sum x_iy_i) - (\overline{y} \cdot \sum x_i) - (\overline{x} \cdot \sum y_i) + \Big((R - L + 1) \cdot \overline{x} \cdot \overline{y}\Big)}{(\sum x_i^2) - (2\overline{x}\sum x_i) + \Big((R - L + 1)\overline{x}^2\Big)} \\ & = \dfrac{\sum x_iy_i - \frac{\sum x_i \sum y_i}{R - L + 1} - \frac{\sum x_i \sum y_i}{R - L + 1} + \frac{\sum x_i \sum y_i}{R - L + 1}}{\sum x_i^2 - 2 \cdot \frac{(\sum x_i) ^ 2}{R - L + 1} + \frac{(\sum x_i) ^ 2}{R - L + 1}} \\ & = \dfrac{\sum x_iy_i - \frac{\sum x_i \sum y_i}{R - L + 1}}{\sum x_i^2 - \frac{(\sum x_i) ^ 2}{R - L + 1}}\end{aligned}\)
至此,我们搞掉了平均数。观察最终式子,只需维护 \(\sum x_i\)、\(\sum y_i\)、\(\sum x_iy_i\)、\(\sum x_i^2\) 的值。
容易发现题目是区间求、区间改。所以我们用线段树处理。
利用线段树知道这四个值后,直接代入式子,就能计算出答案了。
线段树实现
梳理一下更新操作:
- 区间 \(x_i\) 与 \(y_i\) 加上一个数。
- 区间修改 \(x_i = y_i = i\),然后再来一次操作一。
为了方便,我们可以用一个结构体封装线段树的存储。
对于两个点合并,发现可以直接加起来。为了看起来更酷,我重载了一下加号。
//db 是 double 的意思,为什么用 double 后面会解释
struct Node
{
db x, y, xx, xy, s, t;
bool cov;
Node() {x = y = xx = xy = s = t = cov = 0;} //初始化很重要
} tr[N << 2];
Node operator +(Node a, Node b)
{
Node ans;
ans.x = a.x + b.x;
ans.y = a.y + b.y;
ans.xx = a.xx + b.xx;
ans.xy = a.xy + b.xy;
return ans;
}
那么 pushup
就很简单了。
void pushup(int pos) {tr[pos] = tr[ls(pos)] + tr[rs(pos)];}
考虑线段树重要操作 pushdown
:貌似难很多。推一推式子。
-
\(\sum (x_i + s) = \sum x_i + s \cdot (R - L + 1)\)
-
\(\sum (y_i + t) = \sum y_i + t \cdot (R - L + 1)\)
-
\(\sum (x_i + s)(y_i + t) = \sum x_iy_i + s \sum y_i + t \sum x_i + s \cdot t \cdot (R - L + 1)\)
-
\(\sum (x_i + s)^2 = \sum (x_i^2 + 2 \cdot s \cdot x_i + s^2) = \sum x_i^2 + 2 \cdot s \cdot \sum x_i + s^2 \cdot (R - L + 1)\)
综上,我们可以写出 lazy
函数。
void lazy(int l, int r, int pos, db S, db T)
{
tr[pos].s += S, tr[pos].t += T;
tr[pos].xy += T * tr[pos].x + S * tr[pos].y + S * T * (r - l + 1); //对比原xy暴力展开即可
tr[pos].xx += 2 * S * tr[pos].x + S * S * (r - l + 1); //对比原xx暴力展开即可
tr[pos].x += S * (r - l + 1), tr[pos].y += T * (r - l + 1);
}
不幸的是,我们还要再处理一个东西:区间覆盖 \(x_i = y_i = i\)。
也就是说,我们还需要一个操作:区间重建 \(x_i = y_i = i\)。
前置知识:\(1^2 + 2^2 + 3^2 + \cdots + n^2 = \dfrac{n \cdot (n + 1) \cdot (2n + 1)}{6}\)。不会可以去百度查。
db sqsum(db x) {return x * (x + 1) * (2 * x + 1) / 6;} //1*1 + 2*2 + ... + x*x
db sqsum(int l, int r) {return sqsum(r) - sqsum(l - 1);} //l*l + ... + r*r
void rebuild(db l, db r, int pos) //重建 xi = yi = i
{
tr[pos].s = tr[pos].t = 0, tr[pos].cov = true;
tr[pos].x = tr[pos].y = (l + r) * (r - l + 1) / 2; //等差数列求和
tr[pos].xx = tr[pos].xy = sqsum(l, r);
}
那么,我们终于可以写 pushdown
了。
void pushdown(int l, int r, int pos)
{
int mid = (l + r) >> 1;
if (tr[pos].cov) //如果需要覆盖,那就下传覆盖
{
rebuild(l, mid, ls(pos));
rebuild(mid + 1, r, rs(pos));
tr[pos].cov = false;
}
lazy(l, mid, ls(pos), tr[pos].s, tr[pos].t);
lazy(mid + 1, r, rs(pos), tr[pos].s, tr[pos].t);
tr[pos].s = tr[pos].t = 0;
}
其实到这里就马上搞定了,把其他几个基本操作码掉就完事了。
完整代码
不幸的是,你会发现数据较大,long long
都有机会炸。
你可以直接使用 double
。
代码个人认为挺好看的。
#include <iostream>
#include <cstdio>
#include <cstring>
#define space putchar(' ')
#define endl putchar('\n')
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
typedef double db;
typedef long double LD;
namespace io { //fast read && write 可以直接不管
bool FREAD = false; //do you need to use fread?
void fastio()
{
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
}
char buf[1 << 21], *p1 = buf, *p2 = buf;
inline char getc()
{
if (!FREAD) return getchar();
if (p1 == p2) p2 = buf + fread(buf, 1, 1 << 21, stdin), p1 = buf;
return *(p1++);
}
inline int read()
{
char op = getc(); int x = 0, f = 1;
while (op < 48 || op > 57) {if (op == '-') f = -1; op = getc();}
while (48 <= op && op <= 57) x = (x << 1) + (x << 3) + (op ^ 0x30), op = getc();
return x * f;
}
inline void write(int x)
{
if (x < 0) putchar('-'), x = -x;
if (x > 9) write(x / 10);
putchar(x % 10 + 0x30);
}
} using namespace io;
const int N = 1e5 + 5;
db kx[N], ky[N];
struct Node
{
db x, y, xx, xy, s, t;
bool cov;
Node() {x = y = xx = xy = s = t = cov = 0;}
} tr[N << 2];
Node operator +(Node a, Node b)
{
Node ans;
ans.x = a.x + b.x;
ans.y = a.y + b.y;
ans.xx = a.xx + b.xx;
ans.xy = a.xy + b.xy;
return ans;
}
struct SegmentTree
{
int ls(int x) {return x << 1;}
int rs(int x) {return x << 1 | 1;}
void pushup(int pos) {tr[pos] = tr[ls(pos)] + tr[rs(pos)];}
void build(int l, int r, int pos)
{
if (l == r)
{
tr[pos].x = kx[l], tr[pos].y = ky[l];
tr[pos].xx = kx[l] * kx[l];
tr[pos].xy = kx[l] * ky[l];
return;
}
int mid = (l + r) >> 1;
build(l, mid, ls(pos));
build(mid + 1, r, rs(pos));
pushup(pos);
}
db sqsum(db x) {return x * (x + 1) * (2 * x + 1) / 6;} //1*1 + 2*2 + ... + x*x
db sqsum(int l, int r) {return sqsum(r) - sqsum(l - 1);} //l*l + ... + r*r
void rebuild(db l, db r, int pos) //重建 xi = yi = i
{
tr[pos].s = tr[pos].t = 0, tr[pos].cov = true;
tr[pos].x = tr[pos].y = (l + r) * (r - l + 1) / 2; //等差数列求和
tr[pos].xx = tr[pos].xy = sqsum(l, r);
}
void lazy(int l, int r, int pos, db S, db T)
{
tr[pos].s += S, tr[pos].t += T;
tr[pos].xy += T * tr[pos].x + S * tr[pos].y + S * T * (r - l + 1); //对比原xy暴力展开即可
tr[pos].xx += 2 * S * tr[pos].x + S * S * (r - l + 1); //对比原xx暴力展开即可
tr[pos].x += S * (r - l + 1), tr[pos].y += T * (r - l + 1);
}
void pushdown(int l, int r, int pos)
{
int mid = (l + r) >> 1;
if (tr[pos].cov)
{
rebuild(l, mid, ls(pos));
rebuild(mid + 1, r, rs(pos));
tr[pos].cov = false;
}
lazy(l, mid, ls(pos), tr[pos].s, tr[pos].t);
lazy(mid + 1, r, rs(pos), tr[pos].s, tr[pos].t);
tr[pos].s = tr[pos].t = 0;
}
void update(int L, int R, int l, int r, int pos, db S, db T)
{
if (L <= l && r <= R) return lazy(l, r, pos, S, T);
pushdown(l, r, pos);
int mid = (l + r) >> 1;
if (L <= mid) update(L, R, l, mid, ls(pos), S, T);
if (mid < R) update(L, R, mid + 1, r, rs(pos), S, T);
pushup(pos);
}
void modify(int L, int R, int l, int r, int pos, db S, db T)
{
if (L <= l && r <= R)
{
rebuild(l, r, pos), lazy(l, r, pos, S, T);
return;
}
pushdown(l, r, pos);
int mid = (l + r) >> 1;
if (L <= mid) modify(L, R, l, mid, ls(pos), S, T);
if (mid < R) modify(L, R, mid + 1, r, rs(pos), S, T);
pushup(pos);
}
Node query(int L, int R, int l, int r, int pos)
{
if (L <= l && r <= R) return tr[pos];
pushdown(l, r, pos);
int mid = (l + r) >> 1;
Node ans;
if (L <= mid) ans = ans + query(L, R, l, mid, ls(pos));
if (mid < R) ans = ans + query(L, R, mid + 1, r, rs(pos));
return ans;
}
} seg;
int main()
{
int n = read(), T = read();
for (int i = 1; i <= n; i++) kx[i] = read();
for (int i = 1; i <= n; i++) ky[i] = read();
seg.build(1, n, 1);
while (T--)
{
int op = read(), l = read(), r = read(), s = (op == 1 ? -1 : read()), t = (op == 1 ? -1 : read()); //压行
if (op == 1)
{
Node ans = seg.query(l, r, 1, n, 1); //抄公式
db x = ans.x, y = ans.y, xx = ans.xx, xy = ans.xy;
db fenzi = xy - x * y / (r - l + 1), fenmu = xx - x * x / (r - l + 1);
printf("%.10lf\n", fenzi / fenmu);
}
else if (op == 2) seg.update(l, r, 1, n, 1, s, t);
else if (op == 3) seg.modify(l, r, 1, n, 1, s, t);
}
return 0;
}
码字不易(特别是开头的一长串 \(\LaTeX\)),希望能帮助到大家!