[lnsyoj2286/luoguP4458/BJOI2018]链上二次求和
题意
给定长度为 \(n\) 的序列 \(a\),要求支持修改与查询两种操作(两种操作中若 \(l>r\) 均需先交换 \(l,r\)):修改操作给定区间 \([l,r]\) 与 \(d\),将区间内每个数加上 \(d\);查询操作给定 \(l,r\),求所有长度在 \([l,r]\) 之间的子区间的元素和之和(对 \(10^9+7\) 取模),即:
\[\sum_{len=l}^r\sum_{i=1}^{n-len+1}\sum_{j=i}^{i+len-1}a_j
\]
sol
化简式子(下设 \(sum_i=\sum_{j=1}^i a_j,ssum_i=\sum_{j=1}^i sum_j\)):
\[\begin{align} \nonumber
&\sum_{len=l}^r\sum_{i=1}^{n-len+1}\sum_{j=i}^{i+len-1}a_j \\ \nonumber
=&\sum_{len=l}^r\sum_{i=1}^{n-len+1}(sum_{i+len-1}-sum_{i-1}) \\ \nonumber
=&\sum_{len=l}^r\left[\left(\sum_{i=1}^{n-len+1}sum_{i+len-1}\right)-\left(\sum_{i=1}^{n-len+1}sum_{i-1}\right)\right] \\ \nonumber
=&\sum_{len=l}^r (ssum_n-ssum_{len-1}-ssum_{n-len}) \\ \nonumber
=&ssum_n \cdot (r - l + 1) - \sum_{len=l}^r ssum_{len-1} - \sum_{len=l}^r ssum_{n-len} \nonumber
\end{align}
\]
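其中化为 \(ssum\) 的一步利用了 \(\sum_{i=1}^{n-len+1}sum_{i+len-1}=\sum_{k=len}^{n}sum_k=ssum_n-ssum_{len-1}\) 与 \(\sum_{i=1}^{n-len+1}sum_{i-1}=\sum_{k=0}^{n-len}sum_k=ssum_{n-len}\)(约定 \(sum_0=ssum_0=0\))。因此,我们只需要维护支持区间修改的二阶前缀和 \(ssum\),并支持区间求和查询即可。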
易得,区间修改 \([l,r]\) 会对 \(sum_i\) 带来的贡献为
\[\Delta sum_i=
\left \{ \begin{matrix}
&0&(i<l) \\
&d \cdot (i - l + 1) &(l \le i \le r)\\
&d \cdot (r - l + 1) &(i > r)
\end{matrix}\right.
\]
从而可得,区间修改 \([l, r]\) 会对 \(ssum_i\) 带来的贡献为
\[\Delta ssum_i=
\left \{ \begin{matrix}
&0&(i<l) \\
&d \cdot \sum_{j=l}^i (j - l + 1) &(l \le i \le r)\\
&d \cdot [(i-r)(r-l+1) + \sum_{j=l}^r (j - l + 1)] &(i > r)
\end{matrix}\right.
\]
我们将 \([l,r]\) 与 \([r+1,n]\) 两段分别处理:在每一段内,\(\Delta ssum_i\) 都是关于 \(i\) 的至多二次的多项式(系数只与 \(l,r,d\) 有关),因此可以将每一段写作 \(ai^2 + bi + c\) 的形式,具体地:
在 \([l,r]\) 中:
\[\begin{align}
&a=\frac{d}{2} \nonumber \\
&b=\frac{d}{2} \cdot (-2l + 3) \nonumber \\
&c=\frac{d}{2} \cdot (l^2-3l+2) \nonumber
\end{align}
\]
在 \([r + 1,n]\) 中:
\[\begin{align}
&a=0 \nonumber \\
&b= d \cdot (r - l + 1) \nonumber \\
&c= d \cdot (\sum_{k=1}^{r - l + 1}k - r ( r - l + 1)) \nonumber
\end{align}
\]
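这样,线段树的懒标记只需记录系数 \(a,b,c\)(多次修改的系数直接相加即可),而给节点区间 \([L,R]\) 上每个位置加上 \(ai^2+bi+c\) 对区间和的贡献为 \(a\sum_{i=L}^R i^2 + b\sum_{i=L}^R i + c(R-L+1)\),由此即可维护并计算出区间和。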
代码
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
typedef long long LL;
// N: array capacity (max n plus slack); mod: the answer modulus (prime);
// INV2 / INV6: modular inverses of 2 and 6 modulo mod.
const int N = 200005, mod = 1000000007, INV2 = 500000004, INV6 = 166666668;
// Segment tree: tr = node sum; la/lb/lc = pending coefficients (a, b, c) of
// a*i^2 + b*i + c that must still be added to every position i of the node.
int tr[N * 4], la[N * 4], lb[N * 4], lc[N * 4];
int n, m;
// Read as the input sequence, then turned in place into second-order
// prefix sums (two prefix-sum passes in main) before the tree is built.
int a[N];
// Returns 1 + 2 + ... + x modulo mod, for x >= 0.
int sum1(int x){
	return (LL) x * (x + 1) % mod * INV2 % mod;
}
// Returns 1^2 + 2^2 + ... + x^2 modulo mod, for x >= 0.
// Reduce after every multiplication: forming x(x+1)(2x+1) in full
// overflows long long as soon as x exceeds roughly 1.66e6.
int sum2(int x){
	return (LL) x * (x + 1) % mod * (2LL * x + 1) % mod * INV6 % mod;
}
// Recomputes node u's sum from its two children (modulo mod).
void pushup(int u){
	const int left_sum = tr[2 * u];
	const int right_sum = tr[2 * u + 1];
	tr[u] = (left_sum + right_sum) % mod;
}
void pushdown(int u, int l, int r){
int mid = l + r >> 1;
if (l != r){
la[u << 1] = (la[u << 1] + la[u]) % mod;
la[u << 1 | 1] = (la[u << 1 | 1] + la[u]) % mod;
lb[u << 1] = (lb[u << 1] + lb[u]) % mod;
lb[u << 1 | 1] = (lb[u << 1 | 1] + lb[u]) % mod;
lc[u << 1] = (lc[u << 1] + lc[u]) % mod;
lc[u << 1 | 1] = (lc[u << 1 | 1] + lc[u]) % mod;
tr[u << 1] = (((LL) la[u] * (sum2(mid) - sum2(l - 1)) % mod + (LL) lb[u] * (sum1(mid) - sum1(l - 1)) % mod + (LL) lc[u] * (mid - l + 1) + tr[u << 1]) % mod + mod) % mod;
tr[u << 1 | 1] = (((LL) la[u] * (sum2(r) - sum2(mid)) % mod + (LL) lb[u] * (sum1(r) - sum1(mid)) % mod + (LL) lc[u] * (r - mid) + tr[u << 1 | 1]) % mod + mod) % mod;
}
la[u] = lb[u] = lc[u] = 0;
}
// Builds the segment tree node u over positions [l, r] from the array a[]
// (which main has already turned into second-order prefix sums).
void build(int u, int l, int r){
	// Every freshly built node starts without a pending tag. (The original
	// wrote `la[l] = lb[l] = lc[l];`, indexing by position l instead of node
	// u and copying a value rather than clearing.)
	la[u] = lb[u] = lc[u] = 0;
	if (l == r) {
		tr[u] = a[l];
		return ;
	}
	int mid = l + r >> 1;
	build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
	pushup(u);
}
void update(int u, int l, int r, int L, int R, int a, int b, int c){
if (L <= l && r <= R){
tr[u] = (((LL) a * (sum2(r) - sum2(l - 1)) % mod + (LL) b * (sum1(r) - sum1(l - 1)) % mod + (LL) c * (r - l + 1) + tr[u]) % mod + mod) % mod;
la[u] = (la[u] + a) % mod, lb[u] = (lb[u] + b) % mod, lc[u] = (lc[u] + c) % mod;
return ;
}
pushdown(u, l, r);
int mid = l + r >> 1;
if (L <= mid) update(u << 1, l, mid, L, R, a, b, c);
if (R > mid) update(u << 1 | 1, mid + 1, r, L, R, a, b, c);
pushup(u);
}
// Returns the sum of positions [L, R] modulo mod; u is the node covering [l, r].
// A left bound of 0 is moved to 1 (position 0 is not stored in the tree),
// and an empty range yields 0.
int query(int u, int l, int r, int L, int R){
	if (L == 0) L = 1;
	if (R < L) return 0;
	if (l >= L && r <= R) return tr[u];
	pushdown(u, l, r);
	const int mid = (l + r) / 2;
	int acc = 0;
	if (L <= mid) acc = (acc + query(2 * u, l, mid, L, R)) % mod;
	if (mid < R) acc = (acc + query(2 * u + 1, mid + 1, r, L, R)) % mod;
	return acc;
}
void modify(int l, int r, int d){
update(1, 1, n, l, r, (LL) d * INV2 % mod, ((LL) d * (((3 - 2 * l) % mod + mod) % mod) % mod * INV2 % mod + mod) % mod, ((LL) d * ((LL) l * l - 3 * l + 2) % mod * INV2 % mod + mod) % mod);
if (n > r) update(1, 1, n, r + 1, n, 0, (LL) d * (r - l + 1) % mod, ((LL) d * (sum1(r - l + 1) - (LL) r * (r - l + 1) % mod) % mod + mod) % mod);
}
int get_ans(int l, int r){
return (((LL) query(1, 1, n, n, n) * (r - l + 1) % mod - query(1, 1, n, l - 1, r - 1) - query(1, 1, n, n - r, n - l)) % mod + mod) % mod;
}
// Reads n, m and the sequence, converts it into second-order prefix sums,
// builds the tree, then processes m operations:
//   "1 l r d" -> add d on [l, r];  otherwise "op l r" -> print the answer.
// In both cases l and r are swapped first if l > r.
int main(){
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);
	// Two prefix-sum passes: a -> sum -> ssum (a[0] stays 0).
	for (int pass = 0; pass < 2; pass ++ )
		for (int i = 1; i <= n; i ++ ) a[i] = (a[i - 1] + a[i]) % mod;
	build(1, 1, n);
	for (int remaining = m; remaining --; ){
		int op, l, r;
		scanf("%d%d%d", &op, &l, &r);
		if (r < l) std::swap(l, r);
		if (op != 1){
			printf("%d\n", get_ans(l, r));
			continue;
		}
		int d;
		scanf("%d", &d);
		modify(l, r, d);
	}
	return 0;
}