P3707 题解

前言

题目传送门!

更好的阅读体验?

思路和码量都有难度,思路难点是推式子,代码难度是线段树操作很奇怪,导致很难打。

思路

转化式子

首先将 \(a\) 爆拆,目标就是把平均值搞掉。(有点恐怖,建议仔细看;看不懂没关系,记住最后化简的式子就行)

\(\begin{aligned} a & = \dfrac{\sum_{i = L}^R (x_i - \overline{x})\cdot(y_i - \overline{y})}{\sum_{i = L}^R (x_i - \overline{x})^2} \\ & = \dfrac{\sum_{i = L}^R (x_i\cdot y_i - x_i\cdot \overline{y} - y_i\cdot \overline{x} + \overline{x} \cdot \overline{y})}{\sum_{i = L}^R x_i^2 - 2x_i\overline{x} + \overline{x} ^ 2}\\ & = \dfrac{(\sum x_iy_i) - (\overline{y} \cdot \sum x_i) - (\overline{x} \cdot \sum y_i) + \Big((R - L + 1) \cdot \overline{x} \cdot \overline{y}\Big)}{(\sum x_i^2) - (2\overline{x}\sum x_i) + \Big((R - L + 1)\overline{x}^2\Big)} \\ & = \dfrac{\sum x_iy_i - \frac{\sum x_i \sum y_i}{R - L + 1} - \frac{\sum x_i \sum y_i}{R - L + 1} + \frac{\sum x_i \sum y_i}{R - L + 1}}{\sum x_i^2 - 2 \cdot \frac{(\sum x_i) ^ 2}{R - L + 1} + \frac{(\sum x_i) ^ 2}{R - L + 1}} \\ & = \dfrac{\sum x_iy_i - \frac{\sum x_i \sum y_i}{R - L + 1}}{\sum x_i^2 - \frac{(\sum x_i) ^ 2}{R - L + 1}}\end{aligned}\)

至此,我们搞掉了平均数。观察最终式子,只需维护 \(\sum x_i\)\(\sum y_i\)\(\sum x_iy_i\)\(\sum x_i^2\) 的值。

容易发现题目是区间求、区间改。所以我们用线段树处理。

利用线段树知道这四个值后,直接代入式子,就能计算出答案了。

线段树实现

梳理一下更新操作:

  1. 区间 \(x_i\)\(y_i\) 加上一个数。
  2. 区间修改 \(x_i = y_i = i\),然后再来一次操作一。

为了方便,我们可以用一个结构体封装线段树的存储。

对于两个点合并,发现可以直接加起来。为了看起来更酷,我重载了一下加号。

//db 是 double 的意思,为什么用 double 后面会解释
struct Node
{
	db x, y, xx, xy, s, t;
	bool cov;
    Node() {x = y = xx = xy = s = t = cov = 0;} //初始化很重要
} tr[N << 2];
Node operator +(Node a, Node b)
{
	Node ans;
	ans.x = a.x + b.x;
	ans.y = a.y + b.y;
	ans.xx = a.xx + b.xx;
	ans.xy = a.xy + b.xy;
	return ans;
}

那么 pushup 就很简单了。

void pushup(int pos) {tr[pos] = tr[ls(pos)] + tr[rs(pos)];}

考虑线段树重要操作 pushdown:貌似难很多。推一推式子。

  1. \(\sum (x_i + s) = \sum x_i + s \cdot (R - L + 1)\)

  2. \(\sum (y_i + t) = \sum y_i + t \cdot (R - L + 1)\)

  3. \(\sum (x_i + s)(y_i + t) = \sum x_iy_i + s \sum y_i + t \sum x_i + s \cdot t \cdot (R - L + 1)\)

  4. \(\sum (x_i + s)^2 = \sum (x_i^2 + 2 \cdot s \cdot x_i + s^2) = \sum x_i^2 + 2 \cdot s \cdot \sum x_i + s^2 \cdot (R - L + 1)\)

综上,我们可以写出 lazy 函数。

void lazy(int l, int r, int pos, db S, db T)
{
    tr[pos].s += S, tr[pos].t += T;
    tr[pos].xy += T * tr[pos].x + S * tr[pos].y + S * T * (r - l + 1); //对比原xy暴力展开即可 
    tr[pos].xx += 2 * S * tr[pos].x + S * S * (r - l + 1); //对比原xx暴力展开即可 
    tr[pos].x += S * (r - l + 1), tr[pos].y += T * (r - l + 1); 
}

不幸的是,我们还要再处理一个东西:区间覆盖 \(x_i = y_i = i\)

也就是说,我们还需要一个操作:区间重建 \(x_i = y_i = i\)

前置知识:\(1^2 + 2^2 + 3^2 + \cdots + n^2 = \dfrac{n \cdot (n + 1) \cdot (2n + 1)}{6}\)。不会可以去百度查。

db sqsum(db x) {return x * (x + 1) * (2 * x + 1) / 6;} //1*1 + 2*2 + ... + x*x
db sqsum(int l, int r) {return sqsum(r) - sqsum(l - 1);} //l*l + ... + r*r
void rebuild(db l, db r, int pos) //重建 xi = yi = i
{
    tr[pos].s = tr[pos].t = 0, tr[pos].cov = true;
    tr[pos].x = tr[pos].y = (l + r) * (r - l + 1) / 2; //等差数列求和
    tr[pos].xx = tr[pos].xy = sqsum(l, r);
}

那么,我们终于可以写 pushdown 了。

void pushdown(int l, int r, int pos)
{
    int mid = (l + r) >> 1;
    if (tr[pos].cov) //如果需要覆盖,那就下传覆盖
    {
        rebuild(l, mid, ls(pos));
        rebuild(mid + 1, r, rs(pos));
        tr[pos].cov = false;
    }
    lazy(l, mid, ls(pos), tr[pos].s, tr[pos].t);
    lazy(mid + 1, r, rs(pos), tr[pos].s, tr[pos].t);
    tr[pos].s = tr[pos].t = 0;
}

其实到这里就马上搞定了,把其他几个基本操作码掉就完事了。

完整代码

不幸的是,你会发现数据较大,long long 都有机会炸。

你可以直接使用 double

代码个人认为挺好看的。

#include <iostream>
#include <cstdio>
#include <cstring>
#define space putchar(' ')
#define endl putchar('\n')
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
typedef double db;
typedef long double LD;
namespace io { //fast read && write 可以直接不管
    bool FREAD = false; //do you need to use fread?
    void fastio()
    {
        ios::sync_with_stdio(false);
        cin.tie(0), cout.tie(0);
    }
    char buf[1 << 21], *p1 = buf, *p2 = buf;
    inline char getc()
    {
        if (!FREAD) return getchar();
        if (p1 == p2) p2 = buf + fread(buf, 1, 1 << 21, stdin), p1 = buf;
        return *(p1++);
    }
    inline int read()
    {
        char op = getc(); int x = 0, f = 1;
        while (op < 48 || op > 57) {if (op == '-') f = -1; op = getc();}
        while (48 <= op && op <= 57) x = (x << 1) + (x << 3) + (op ^ 0x30), op = getc();
        return x * f;
    }
    inline void write(int x)
    {
        if (x < 0) putchar('-'), x = -x;
        if (x > 9) write(x / 10);
        putchar(x % 10 + 0x30);
    }
} using namespace io;
const int N = 1e5 + 5;
db kx[N], ky[N];
struct Node
{
	db x, y, xx, xy, s, t;
	bool cov;
    Node() {x = y = xx = xy = s = t = cov = 0;}
} tr[N << 2];
Node operator +(Node a, Node b)
{
	Node ans;
	ans.x = a.x + b.x;
	ans.y = a.y + b.y;
	ans.xx = a.xx + b.xx;
	ans.xy = a.xy + b.xy;
	return ans;
}
struct SegmentTree
{
    int ls(int x) {return x << 1;}
    int rs(int x) {return x << 1 | 1;}
    void pushup(int pos) {tr[pos] = tr[ls(pos)] + tr[rs(pos)];}
    void build(int l, int r, int pos)
    {
        if (l == r)
        {
            tr[pos].x = kx[l], tr[pos].y = ky[l];
            tr[pos].xx = kx[l] * kx[l];
            tr[pos].xy = kx[l] * ky[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(l, mid, ls(pos));
        build(mid + 1, r, rs(pos));
        pushup(pos);
    }
    db sqsum(db x) {return x * (x + 1) * (2 * x + 1) / 6;} //1*1 + 2*2 + ... + x*x
    db sqsum(int l, int r) {return sqsum(r) - sqsum(l - 1);} //l*l + ... + r*r
    void rebuild(db l, db r, int pos) //重建 xi = yi = i
    {
        tr[pos].s = tr[pos].t = 0, tr[pos].cov = true;
        tr[pos].x = tr[pos].y = (l + r) * (r - l + 1) / 2; //等差数列求和
        tr[pos].xx = tr[pos].xy = sqsum(l, r);
    }
    void lazy(int l, int r, int pos, db S, db T)
    {
        tr[pos].s += S, tr[pos].t += T;
        tr[pos].xy += T * tr[pos].x + S * tr[pos].y + S * T * (r - l + 1); //对比原xy暴力展开即可 
        tr[pos].xx += 2 * S * tr[pos].x + S * S * (r - l + 1); //对比原xx暴力展开即可 
        tr[pos].x += S * (r - l + 1), tr[pos].y += T * (r - l + 1); 
    }
    void pushdown(int l, int r, int pos)
    {
        int mid = (l + r) >> 1;
        if (tr[pos].cov)
        {
            rebuild(l, mid, ls(pos));
            rebuild(mid + 1, r, rs(pos));
            tr[pos].cov = false;
        }
        lazy(l, mid, ls(pos), tr[pos].s, tr[pos].t);
        lazy(mid + 1, r, rs(pos), tr[pos].s, tr[pos].t);
        tr[pos].s = tr[pos].t = 0;
    }
    void update(int L, int R, int l, int r, int pos, db S, db T)
    {
        if (L <= l && r <= R) return lazy(l, r, pos, S, T);
        pushdown(l, r, pos);
        int mid = (l + r) >> 1;
        if (L <= mid) update(L, R, l, mid, ls(pos), S, T);
        if (mid < R) update(L, R, mid + 1, r, rs(pos), S, T);
        pushup(pos);
    }
    void modify(int L, int R, int l, int r, int pos, db S, db T)
    {
        if (L <= l && r <= R)
        {
            rebuild(l, r, pos), lazy(l, r, pos, S, T);
            return;
        }
        pushdown(l, r, pos);
        int mid = (l + r) >> 1;
        if (L <= mid) modify(L, R, l, mid, ls(pos), S, T);
        if (mid < R) modify(L, R, mid + 1, r, rs(pos), S, T);
        pushup(pos);
    }
    Node query(int L, int R, int l, int r, int pos)
    {
        if (L <= l && r <= R) return tr[pos];
        pushdown(l, r, pos);
        int mid = (l + r) >> 1;
		Node ans;
		if (L <= mid) ans = ans + query(L, R, l, mid, ls(pos));
        if (mid < R) ans = ans + query(L, R, mid + 1, r, rs(pos));
        return ans;
    }
} seg;
int main()
{
    int n = read(), T = read();
    for (int i = 1; i <= n; i++) kx[i] = read();
    for (int i = 1; i <= n; i++) ky[i] = read();
    seg.build(1, n, 1);
    while (T--)
    {
        int op = read(), l = read(), r = read(), s = (op == 1 ? -1 : read()), t = (op == 1 ? -1 : read()); //压行 
        if (op == 1)
        {
            Node ans = seg.query(l, r, 1, n, 1); //抄公式
            db x = ans.x, y = ans.y, xx = ans.xx, xy = ans.xy;
            db fenzi = xy - x * y / (r - l + 1), fenmu = xx - x * x / (r - l + 1);
            printf("%.10lf\n", fenzi / fenmu);
        }
        else if (op == 2) seg.update(l, r, 1, n, 1, s, t);
        else if (op == 3) seg.modify(l, r, 1, n, 1, s, t);
    }
    return 0;
}

码字不易(特别是开头的一长串 \(\LaTeX\)),希望能帮助到大家!

posted @ 2022-10-09 20:57  liangbowen  阅读(25)  评论(0编辑  收藏  举报