【[BJOI2018]链上二次求和】
一道挺好的数据结构题。
先用数学式子把题目里要求的答案表示出来,设\(S_i=\sum_{j=1}^{i}a_j\):
\[\text{ans}=\sum_{i=l}^{r}\sum_{j=i}^{n}S_j-S_{j-i}
\]
\[\text{ans=}\sum_{j=l}^{r}(\sum_{j=i}^{n}S_j-\sum_{j=0}^{n-i}S_j)
\]
然后设\(SS_i=\sum_{j=1}^{i}S_i\)
于是
\[\text{ans}=\sum_{j=l}^{r}(SS_n-SS_{i-1}-SS_{n-i})
\]
\[\text{ans}=(r-l+1)SS_{n}-\sum_{i=l-1}^{r-1}SS_i-\sum_{i=n-r}^{n-l}SS_i
\]
这个东西可以直接用线段树维护,复杂度一个\(\text{log}\),接着考虑修改。
如果修改的区间是\([l,r,d]\)的话,那么它对\(SS_i(i\in [1,n])\)的影响就是:
\(1:i\in[1,l)\),无影响。
\(2:i\in[l,r]\),因为\(a_i(i\in[l,r])\)增加了\(\text{d}\),所以\(S_i(i\in[l,r])\)相当于增加了\(i-l+1\)个\(\text{d}\),所以\(SS_i\)就增加了\(\sum_{j=1}^{i}d*(j-l+1)\)
\(3:i \in(r,n]\),我们可以想象一下,在\(>d\)之后\(SS_i\)的增加就固定了,所以再加上前面的贡献就是\(\sum_{j=l}^{r}d*(j-l+1)+(i-r)*d*(r-l+1)\)
然后可以把这两个式子化成一个关于\(i\)的二次函数形式(这一步上了小学的人应该都会,这里就略去了),可以在线段树上维护这个二次函数的\(\text{a,b,c}\)值,这道题就做完了。
所以除了代码细节多多多多多多多多之外也没有什么特别恶心的地方。(顺便提个醒,这题的屑数据还会出现\(l>r\),要判掉)。
My Code
#include <bits/stdc++.h>
typedef long long ll;
const int mod = 1e9 + 7;
const int inv2 = 5e8 + 4;
const int N = 2e5 + 10;
int n, m, i, j, k;
int s[N], ss[N], sum[N << 2], a[N];
int sa[N << 2], sb[N << 2], sc[N << 2];
inline int add(int a, int b) { ll c = a + b; if (c < 0) c += mod; return c >= mod ? c - mod : c; }
struct tag {
int a, b, c;
tag() { a = b = c = 0; }
tag(int _a, int _b, int _c) { a = _a, b = _b, c = _c; }
inline void init() {
a = add(a, 0);
b = add(b, 0);
c = add(c, 0);
}
} laz[N << 2], mdy;
inline void down(int u, tag t) {
int a = t.a, b = t.b, c = t.c;
sum[u] = add(add(add(sum[u], 1ll * a * sa[u] % mod), 1ll * b * sb[u] % mod), 1ll * sc[u] * c % mod);
laz[u].a = add(laz[u].a, a);
laz[u].b = add(laz[u].b, b);
laz[u].c = add(laz[u].c, c);
}
inline void push_down(int u) {
if (laz[u].a || laz[u].b || laz[u].c) {
down(u << 1, laz[u]);
down(u << 1 | 1, laz[u]);
laz[u] = tag(0, 0, 0);
}
}
void build(int l, int r, int u) {
if (l == r) {
sum[u] = ss[l], sc[u] = 1, sb[u] = l, sa[u] = 1ll * l * l % mod;
return;
}
int mid = (l + r) >> 1;
build(l, mid, u << 1);
build(mid + 1, r, u << 1 | 1);
sum[u] = add(sum[u << 1], sum[u << 1 | 1]);
sa[u] = add(sa[u << 1], sa[u << 1 | 1]);
sb[u] = add(sb[u << 1], sb[u << 1 | 1]);
sc[u] = add(sc[u << 1], sc[u << 1 | 1]);
return;
}
void modify(int ml, int mr, int l, int r, int u) {
if (ml <= l && r <= mr) { down(u, mdy); return; }
int mid = (l + r) >> 1;
push_down(u);
if (ml <= mid) modify(ml, mr, l, mid, u << 1);
if (mid < mr) modify(ml, mr, mid + 1, r, u << 1 | 1);
sum[u] = add(sum[u << 1], sum[u << 1 | 1]);
}
int query(int ql, int qr, int l, int r, int u) {
if (ql > qr) return 0; if (ql > r || qr < l) return 0;
if (ql <= l && r <= qr) return sum[u];
push_down(u);
int mid = (l + r) >> 1, res = 0;
if (ql <= mid) res = add(res, query(ql, qr, l, mid, u << 1));
if (mid < qr) res = add(res, query(ql, qr, mid + 1, r, u << 1 | 1));
return res;
}
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", a + i), s[i] = add(s[i - 1], a[i]), ss[i] = add(ss[i - 1], s[i]);
build(1, n, 1);
for (int i = 1, opt, l, r, z; i <= m; i++) {
scanf("%d %d %d", &opt, &l, &r);
if (l > r) std::swap(l, r);
if (opt == 1) {
scanf("%d", &z); int len = r - l + 1, a = (ll)z * inv2 % mod;
mdy = tag(a, (ll)a * (3 - 2 * l + mod) % mod,((ll)l * l - 3 * l + 2) % mod * a % mod);
modify(l, r, 1, n, 1);
if (r == n) continue;
mdy = tag(0, (ll)len * z % mod, ((ll)len * (len + 1) / 2 % mod * z % mod - (ll)len * r % mod * z % mod + mod) % mod);
modify(r + 1, n, 1, n, 1);
} else {
int len = r - l + 1;
ll ans = ((ll)len * query(n, n, 1, n, 1) % mod - query(std::max(l - 1, 1), r - 1, 1, n, 1) + mod - query(std::max(n - r, 1), n - l, 1, n, 1) + mod) % mod;
printf("%lld\n", ans);
}
}
return 0;
}