[lnsyoj2286/luoguP4458/BJOI2018]链上二次求和

题意

给定序列 \(a\),要求支持修改与查询操作:修改操作为对区间 \(l,r\) 的每个数 \(+d\),查询操作为给定区间 \(l,r\),要求查询:

\[\sum_{len=l}^r\sum_{i=1}^{n-len+1}\sum_{j=i}^{i+len-1}a_j \]

sol

化简式子(下设 \(sum_i=\sum_{j=1}^i a_j,ssum_i=\sum_{j=1}^i sum_j\)):

\[\begin{align} \nonumber &\sum_{len=l}^r\sum_{i=1}^{n-len+1}\sum_{j=i}^{i+len-1}a_j \\ \nonumber =&\sum_{len=l}^r\sum_{i=1}^{n-len+1}(sum_{i+len-1}-sum_{i-1}) \\ \nonumber =&\sum_{len=l}^r\left[\left(\sum_{i=1}^{n-len+1}sum_{i+len-1}\right)-\left(\sum_{i=1}^{n-len+1}sum_{i-1}\right)\right] \\ \nonumber =&\sum_{len=l}^r (ssum_n-ssum_{len-1}-ssum_{n-len}) \\ \nonumber =&ssum_n \cdot (r - l + 1) - \sum_{len=l}^r ssum_{len-1} - \sum_{len=l}^r ssum_{n-len} \nonumber \end{align} \]

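因此,我们只需要维护带有区间修改的"前缀和的前缀和" \(ssum\),并支持对 \(ssum\) 的区间和查询即可:上式后两项分别是 \(ssum\) 在下标区间 \([l-1,r-1]\) 与 \([n-r,n-l]\) 上的区间和(约定 \(ssum_0=0\))。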
易得,区间修改 \([l,r]\) 会对 \(sum_i\) 带来的贡献为

\[\Delta sum_i= \left \{ \begin{matrix} &0&(i<l) \\ &d \cdot (i - l + 1) &(l \le i \le r)\\ &d \cdot (r - l + 1) &(i > r) \end{matrix}\right. \]

从而可得,区间修改 \([l, r]\) 会对 \(ssum_i\) 带来的贡献为

\[\Delta ssum_i= \left \{ \begin{matrix} &0&(i<l) \\ &d \cdot \sum_{j=l}^i (j - l + 1) &(l \le i \le r)\\ &d \cdot [(i-r)(r-l+1) + \sum_{j=l}^r (j - l + 1)] &(i > r) \end{matrix}\right. \]

我们将 \([l,r],[r+1,n]\) 两段分别处理,易得在每一段中只有 \(i\) 不同,因此我们可以将每一段写作 \(ai^2 + bi + c\) 的形式,具体地:
\([l,r]\) 中:

\[\begin{align} &a=\frac{d}{2} \nonumber \\ &b=\frac{d}{2} \cdot (-2l + 3) \nonumber \\ &c=\frac{d}{2} \cdot (l^2-3l+2) \nonumber \end{align} \]

\([r + 1,n]\) 中:

\[\begin{align} &a=0 \nonumber \\ &b= d \cdot (r - l + 1) \nonumber \\ &c= d \cdot (\sum_{i=1}^{r - l + 1}i - r ( r - l + 1)) \nonumber \\ \end{align} \]

这样,我们就可以通过维护 \(a,b,c\) 的方式,计算出区间和。

代码

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;

// 64-bit integer type used for every intermediate modular product.
typedef long long LL;

// N: capacity bound for the sequence; mod: modulus of all arithmetic.
const int N = 200005;
const int mod = 1000000007;
// Modular inverses of 2 and 6 under mod, used by the closed-form power sums.
const int INV2 = 500000004;
const int INV6 = 166666668;

// Segment tree storage, 4 * N slots per array.
// tr holds each node's sum; la / lb / lc hold the node's pending lazy
// coefficients of i^2, i and 1 respectively.
int tr[N * 4];
int la[N * 4];
int lb[N * 4];
int lc[N * 4];

// n: sequence length, m: number of operations (both read in main).
int n, m;

// Input values; main overwrites them in place with two rounds of prefix sums.
int a[N];

int sum1(int x){
    return (LL) x * (x + 1) % mod * INV2 % mod;
}

int sum2(int x){
    return (LL) x * (x + 1) * (2 * x + 1) % mod * INV6 % mod;
}

// Recompute the sum stored at node u from the sums of its two children.
void pushup(int u){
    int left_sum = tr[u << 1];
    int right_sum = tr[u << 1 | 1];
    tr[u] = (left_sum + right_sum) % mod;
}

void pushdown(int u, int l, int r){
    int mid = l + r >> 1;
    if (l != r){
        la[u << 1] = (la[u << 1] + la[u]) % mod;
        la[u << 1 | 1] = (la[u << 1 | 1] + la[u]) % mod;
        lb[u << 1] = (lb[u << 1] + lb[u]) % mod;
        lb[u << 1 | 1] = (lb[u << 1 | 1] + lb[u]) % mod;
        lc[u << 1] = (lc[u << 1] + lc[u]) % mod;
        lc[u << 1 | 1] = (lc[u << 1 | 1] + lc[u]) % mod;
        tr[u << 1] = (((LL) la[u] * (sum2(mid) - sum2(l - 1)) % mod + (LL) lb[u] * (sum1(mid) - sum1(l - 1)) % mod + (LL) lc[u] * (mid - l + 1) + tr[u << 1]) % mod + mod) % mod;
        tr[u << 1 | 1] = (((LL) la[u] * (sum2(r) - sum2(mid)) % mod + (LL) lb[u] * (sum1(r) - sum1(mid)) % mod + (LL) lc[u] * (r - mid) + tr[u << 1 | 1]) % mod + mod) % mod;
    }
    la[u] = lb[u] = lc[u] = 0;
}

// Build the segment tree for node u covering positions [l, r].
// Leaves take their value from a[l]; inner nodes are the sum of their children.
void build(int u, int l, int r){
    // A freshly built node has no pending polynomial. The tags are indexed by
    // the node id u, not by the position l.
    la[u] = lb[u] = lc[u] = 0;

    if (l == r) {
        tr[u] = a[l];
        return ;
    }

    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

void update(int u, int l, int r, int L, int R, int a, int b, int c){
    if (L <= l && r <= R){
        tr[u] = (((LL) a * (sum2(r) - sum2(l - 1)) % mod + (LL) b * (sum1(r) - sum1(l - 1)) % mod + (LL) c * (r - l + 1) + tr[u]) % mod + mod) % mod;
        la[u] = (la[u] + a) % mod, lb[u] = (lb[u] + b) % mod, lc[u] = (lc[u] + c) % mod;
        return ;
    }

    pushdown(u, l, r);

    int mid = l + r >> 1;
    if (L <= mid) update(u << 1, l, mid, L, R, a, b, c);
    if (R > mid) update(u << 1 | 1, mid + 1, r, L, R, a, b, c);

    pushup(u);
}

// Sum of the stored values over positions [L, R], modulo mod.
// u is the current node and [l, r] the range it covers.
// A left bound of 0 is raised to 1, and an empty range yields 0.
int query(int u, int l, int r, int L, int R){
    if (L == 0) L = 1;
    if (R < L) return 0;
    if (l >= L && R >= r) return tr[u];

    pushdown(u, l, r);

    const int mid = l + r >> 1;
    int total = 0;
    if (mid >= L) total = (total + query(u << 1, l, mid, L, R)) % mod;
    if (mid < R) total = (total + query(u << 1 | 1, mid + 1, r, L, R)) % mod;

    return total;
}

void modify(int l, int r, int d){
    update(1, 1, n, l, r, (LL) d * INV2 % mod, ((LL) d * (((3 - 2 * l) % mod + mod) % mod) % mod * INV2 % mod + mod) % mod, ((LL) d * ((LL) l * l - 3 * l + 2) % mod * INV2 % mod + mod) % mod);
    if (n > r) update(1, 1, n, r + 1, n, 0, (LL) d * (r - l + 1) % mod, ((LL) d * (sum1(r - l + 1) - (LL) r * (r - l + 1) % mod) % mod + mod) % mod);
}

int get_ans(int l, int r){
    return (((LL) query(1, 1, n, n, n) * (r - l + 1) % mod - query(1, 1, n, l - 1, r - 1) - query(1, 1, n, n - r, n - l)) % mod + mod) % mod;
}

// Read the sequence, turn it into second-order prefix sums, build the tree,
// then process m operations: op 1 is a range add, anything else is a query.
int main(){
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);

    // Two rounds of in-place prefix sums: a[i] becomes the prefix sum of
    // prefix sums, reduced modulo mod.
    for (int round = 0; round < 2; ++round)
        for (int i = 1; i <= n; ++i)
            a[i] = (a[i - 1] + a[i]) % mod;

    build(1, 1, n);

    while (m -- ){
        int op, l, r;
        scanf("%d%d%d", &op, &l, &r);
        // Endpoints may arrive in either order.
        if (r < l) swap(l, r);

        if (op != 1){
            printf("%d\n", get_ans(l, r));
            continue;
        }

        int d;
        scanf("%d", &d);
        modify(l, r, d);
    }

    return 0;
}
posted @ 2024-08-20 19:17  是一只小蒟蒻呀  阅读(28)  评论(0编辑  收藏  举报