[BJOI2018] 链上二次求和 题解

Description

给你一根 \(n\) 个点的链,每个点有权值 \(a_i\) ,要求支持两种操作:

  • \(u,v\) 间的数加上 \(d\)
  • 询问链上所有长度在 \([l,r]\) 间的路径权值和。

\(n\le2\times10^5\)

Sol

考虑每个点对于答案产生的贡献,手玩一下可以得出下面的表(其中点 \((i,j)\) 表示点 \(j\) 在所有长度为 \(i\) 的路径上的出现次数)。

\[\begin{gathered} \begin{bmatrix} 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & 2 & 2 & 2 & 2 & 2 & 1 \\1 & 2 & 3 & 3 & 3 & 2 & 1 \\1 & 2 & 3 & 4 & 3 & 2 & 1 \\1 & 2 & 3 & 3 & 3 & 2 & 1 \\1 & 2 & 2 & 2 & 2 & 2 & 1 \\1 & 1 & 1 & 1 & 1 & 1 & 1 \\\end{bmatrix} \quad \end{gathered} \]

观察到矩阵很有规律,我们考虑线段树维护所有长度为 \(i\) 的路径的总全职和,即线段树最底层一个点维护的是矩阵的一行的所有点的 \(a_i\times cnt_{i,j}\) ,直接维护因为每一行的 \(cnt\) 都不一样,我们考虑换一种方式:

我们假设每一条路径上结点的出现次数都是形如 \(1234321\) 类型的,那么我们的答案就是总贡献减去多出的贡献。

总贡献即 \(a_1+2a_2+3a_3+...+(k-1)a{k-1}+ka_k+(k-1)a_{k+1}+na_n\),拿线段树简单维护即可,考虑多出的贡献,假设我们查询长度在 \([2,5]\) 之间的路径权值和,那么多出来的部分即为:

那么我们多出来的贡献即为 \(a_2+4a_3+9a_4+4a_5+a_6\),正着维护 \(1^2a_1+2^2a_2+3^2a_3+...+n^2a_n\)\(n^2a_1+(n-1)^2a_2+(n-2)^2a_3+...+a_n\) ,查询时减去即可。

还有一种维护方式就是将上图继续分割,分成 \(4\) 个小块,如图所示:

分别计算每一块的贡献第一块为 \(1,3\),第二块为 \(1\),第三块为 \(1,3,6\),第四块为 \(1,3\)

那么我们维护 \(1a_1+3a_2+6a_3+10a_4+...\)\(1a_n+3a_{n-1}+6a_{n-2}+...\) 即可。

修改时维护上面的贡献即可。

时间复杂度 \(O(n\log n)\)

Code

代码实现有那么亿点点恶心...

#include<bits/stdc++.h>
#define int long long
#define Mod 1000000007
#define inv6 166666668
#define il inline
#define re register
using namespace std;
il int Read() {
    int x = 0, f = 1; char ch = getchar();
    while(!isdigit(ch)) {if(ch == '-')  f = -1; ch = getchar();}
    while(isdigit(ch)) {x = (x << 3) + (x << 1) + ch - '0'; ch = getchar();}
    return x * f;
}
struct node {
    int sz, sum, sum2, sum3, _sum2, _sum3, addv;
    node(int Sz = 0, int Sum = 0, int Sum2 = 0, int Sum3 = 0, int _Sum2 = 0, int _Sum3 = 0, int Addv = 0) {sz = Sz, _sum2 = _Sum2, _sum3 = _Sum3, sum = Sum, sum2 = Sum2, sum3 = Sum3, addv = Addv;}
}seg[800005];
il node Merge(node A, node B) {
    node C;
    C.sz = A.sz + B.sz;
    C.sum = (A.sum + B.sum) % Mod;
    C.sum2 = (A.sum2 + B.sum2 + A.sz * B.sum % Mod) % Mod;
    C.sum3 = (A.sum3 + B.sum3 + A.sz * B.sum2 % Mod + (A.sz + 1) * A.sz / 2 % Mod * B.sum % Mod) % Mod;
    C._sum2 = (B._sum2 + A._sum2 + B.sz * A.sum % Mod) % Mod;
    C._sum3 = (B._sum3 + A._sum3 + B.sz * A._sum2 % Mod + (B.sz + 1) * B.sz / 2 % Mod * A.sum % Mod) % Mod;
    return C;
}
int a[200005];
il void build(int o, int l, int r) {
    seg[o].sz = r - l + 1;
    if(l == r) {
        seg[o].sum = seg[o].sum2 = seg[o].sum3 = seg[o]._sum2 = seg[o]._sum3 = a[l];
        return ;
    }
    int mid = (l + r) >> 1;
    build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r);
    seg[o] = Merge(seg[o << 1], seg[o << 1 | 1]);
}
il void pusha(int o, int l, int r, int x) {
    int len = (r - l + 1);
    (seg[o].sum3 += (len * (len + 1) % Mod * (len + 2) % Mod * inv6 % Mod) * x % Mod) %= Mod;
    (seg[o]._sum3 += (len * (len + 1) % Mod * (len + 2) % Mod * inv6 % Mod) * x % Mod) %= Mod;
    (seg[o].sum2 += (len * (len + 1) / 2 % Mod) * x % Mod) %= Mod;
    (seg[o]._sum2 += (len * (len + 1) / 2 % Mod) * x % Mod) %= Mod;
    (seg[o].sum += len * x % Mod) %= Mod;
    seg[o].addv = (seg[o].addv + x) % Mod;
}
il void pushdown(int o, int l, int r) {
    int mid = (l + r) >> 1;
    if(seg[o].addv) {
        pusha(o << 1, l, mid, seg[o].addv);
        pusha(o << 1 | 1, mid + 1, r, seg[o].addv);
        seg[o].addv = 0;
    }
}
il void modify(int o, int l, int r, int nl, int nr, int x) {
    if(nl <= l && r <= nr)  return pusha(o, l, r, x);
    pushdown(o, l, r); int mid = (l + r) >> 1;
    if(nl <= mid)  modify(o << 1, l, mid, nl, nr, x);
    if(mid < nr)  modify(o << 1 | 1, mid + 1, r, nl, nr, x);
    seg[o] = Merge(seg[o << 1], seg[o << 1 | 1]);
}
il node query(int o, int l, int r, int nl, int nr) {
    if(nl > nr)  return (node){0ll, 0ll, 0ll, 0ll, 0ll, 0ll, 0ll};
    if(nl <= l && r <= nr)  return seg[o];
    pushdown(o, l, r); int mid = (l + r) >> 1;
    if(nl <= mid) {
        if(mid < nr)  return Merge(query(o << 1, l, mid, nl, nr), query(o << 1 | 1, mid + 1, r, nl, nr));
        return query(o << 1, l, mid, nl, nr);
    }
    return query(o << 1 | 1, mid + 1, r, nl, nr);
}
signed main() {
    int n = Read(), m = Read();
    for(re int i = 1; i <= n; i++)  a[i] = Read();
    build(1, 1, n);
    for(re int i = 1; i <= m; i++) {
        int opt = Read(), x = Read(), y = Read(), z;
        if(x > y)  swap(x, y);
        if(opt == 1) {
            z = Read();
            modify(1, 1, n, x, y, z);
        }
        else {
            if(x > y)  {puts("0"); continue;}
            if(n % 2 == 1) {
                int mid = (1 + n) >> 1; node A, B, C, D;
                A = query(1, 1, n, 1, mid); B = query(1, 1, n, mid + 1, n);
                int ans = (A.sum2 + B._sum2) % Mod * (y - x + 1) % Mod;
                if(y <= mid) {
                    if(y == mid)  --y;
                    A = query(1, 1, n, x + 1, mid); B = query(1, 1, n, mid + 1, 2 * mid - x - 1);
                    C = query(1, 1 ,n, y + 2, mid); D = query(1, 1, n, mid + 1, 2 * mid - y - 2);
                    ans -= (A.sum3 + B._sum3 - C.sum3 - D._sum3 + Mod) % Mod;
                    ans = (ans + Mod) % Mod;
                }
                else if(x >= mid) {
                    if(x == mid)  ++x;
                    A = query(1, 1, n, 2 * mid - y + 1, mid); B = query(1, 1, n, mid + 1, y - 1);
                    C = query(1, 1, n, 2 * mid - x + 2, mid); D = query(1, 1, n, mid + 1, x - 2);
                    ans -= (A.sum3 + B._sum3 - C.sum3 - D._sum3 + Mod) % Mod;
                    ans = (ans + Mod) % Mod;
                }
                else {
                    A = query(1, 1, n, x + 1, mid), B = query(1, 1, n, mid + 1, 2 * mid - x - 1);
                    C = query(1, 1, n, 2 * mid - y + 1, mid), D = query(1, 1, n, mid + 1, y - 1);
                    ans -= (A.sum3 + C.sum3 + B._sum3 + D._sum3) % Mod;
                    ans = (ans + Mod) % Mod;
                }
                printf("%lld\n", ans);
            }
            if(n % 2 == 0) {
                int Lmid = (1 + n) >> 1, Rmid = Lmid + 1; node A, B, C, D;
                A = query(1, 1, n, 1, Lmid); B = query(1, 1, n, Rmid, n);
                int ans = (A.sum2 + B._sum2) % Mod * (y - x + 1) % Mod;
                if(y <= Rmid) {
                    if(y == Rmid)  --y; 
                    if(y == Lmid)  --y;
                    A = query(1, 1, n, x + 1, Lmid); B = query(1, 1, n, Rmid, 2 * Lmid - x);
                    C = query(1, 1, n, y + 2, Lmid); D = query(1, 1, n, Rmid, 2 * Lmid - y - 1);
                    ans -= (A.sum3 + B._sum3 - C.sum3 - D._sum3 + Mod) % Mod;
                    ans = (ans + Mod) % Mod;
                }
                else if(x >= Lmid) {
                    if(x == Lmid)  ++x;
                    if(x == Rmid)  ++x;
                    A = query(1, 1, n, 2 * Lmid - y + 2, Lmid); B = query(1, 1, n, Rmid, y - 1);
                    C = query(1, 1, n, 2 * Lmid - x + 3, Lmid); D = query(1, 1, n, Rmid, x - 2);
                    ans -= (A.sum3 + B._sum3 - C.sum3 - D._sum3 + Mod) % Mod;
                    ans = (ans + Mod) % Mod;
                }
                else {
                    A = query(1, 1, n, x + 1, Lmid), B = query(1, 1, n, Rmid, 2 * Lmid - x);
                    C = query(1, 1, n, 2 * Lmid - y + 2, Lmid), D = query(1, 1, n, Rmid, y - 1);
                    ans -= (A.sum3 + C.sum3 + B._sum3 + D._sum3) % Mod;
                    ans = (ans + Mod) % Mod;
                }
                printf("%lld\n", ans);
            }    
        }
    }
    return 0;
}
posted @ 2020-10-26 17:14  verjun  阅读(123)  评论(0编辑  收藏  举报