P6242 【模板】线段树 3 线段树维护历史最值+区间取min
P6242 【模板】线段树 3
线段树维护历史最值+区间取min。
区间取min:
线段树维护一个区间最大值\((MaxA)\)和严格次大值\((se)\),还要维护最大值个数\(cnt\),区间和\(sum\),然后分情况:(设当前与\(k\)取min)
当\(k >= t[o].MaxA\)时,直接返回;
当\(t[o].se < k < t[o].MaxA\)时,\(t[o].sum += t[o].cnt * (k - t[o].MaxA)\),\(t[o].MaxA = k\);
当\(k <= t[o].se\)时,继续往下递归。
具体维护看代码:
void up(int o) {
if(t[ls(o)].MaxA > t[rs(o)].MaxA) {
t[o].MaxA = t[ls(o)].MaxA; t[o].cnt = t[ls(o)].cnt;
t[o].se = max(t[ls(o)].se, t[rs(o)].MaxA);
}
else if(t[ls(o)].MaxA < t[rs(o)].MaxA) {
t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[rs(o)].cnt;
t[o].se = max(t[rs(o)].se, t[ls(o)].MaxA);
}
else {
t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[ls(o)].cnt + t[rs(o)].cnt;
t[o].se = max(t[ls(o)].se, t[rs(o)].se);
}
t[o].MaxB = max(t[ls(o)].MaxB, t[rs(o)].MaxB);
t[o].sum = t[ls(o)].sum + t[rs(o)].sum;
}
维护历史最大值:
要维护4个标记:最大值加减标记\((add1)\),最大值历史最大加减标记\((add1\)_ \()\),非最大值加减标记\((add2)\),非最大值历史最大加减标记\((add2\)_\()\)。
void modify(int o, int l, int r, int add1, int add1_, int add2, int add2_) {
t[o].sum += 1ll * add1 * t[o].cnt + 1ll * add2 * (r - l + 1 - t[o].cnt);
t[o].MaxB = max(t[o].MaxB, t[o].MaxA + add1_); // MaxB代表历史最大值,用add1_更新
t[o].add1_ = max(t[o].add1_, t[o].add1 + add1_); //标记也记得更新
t[o].MaxA += add1; t[o].add1 += add1;
t[o].add2_ = max(t[o].add2_, t[o].add2 + add2_);
if(t[o].se != -inf) t[o].se += add2; t[o].add2 += add2;
}
完整代码:
#include <bits/stdc++.h>
#define ls(o) (o << 1)
#define rs(o) (o << 1 | 1)
#define mid ((l + r) >> 1)
using namespace std;
inline long long read() {
long long s = 0, f = 1; char ch;
while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
return s * f;
}
const int N = 1e6 + 5, inf = 2e9;
int n, m;
long long x;
struct tree {
long long sum;
int MaxA, MaxB, cnt, se;
int add1, add1_, add2, add2_;
} t[N << 2];
void up(int o) {
if(t[ls(o)].MaxA > t[rs(o)].MaxA) {
t[o].MaxA = t[ls(o)].MaxA; t[o].cnt = t[ls(o)].cnt;
t[o].se = max(t[ls(o)].se, t[rs(o)].MaxA);
}
else if(t[ls(o)].MaxA < t[rs(o)].MaxA) {
t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[rs(o)].cnt;
t[o].se = max(t[rs(o)].se, t[ls(o)].MaxA);
}
else {
t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[ls(o)].cnt + t[rs(o)].cnt;
t[o].se = max(t[ls(o)].se, t[rs(o)].se);
}
t[o].MaxB = max(t[ls(o)].MaxB, t[rs(o)].MaxB);
t[o].sum = t[ls(o)].sum + t[rs(o)].sum;
}
void build(int o, int l, int r) {
if(l == r) {
t[o].MaxA = t[o].MaxB = t[o].sum = read();
t[o].se = -inf; t[o].cnt = 1;
return ;
}
build(ls(o), l, mid); build(rs(o), mid + 1, r);
up(o);
}
void modify(int o, int l, int r, int add1, int add1_, int add2, int add2_) {
t[o].sum += 1ll * add1 * t[o].cnt + 1ll * add2 * (r - l + 1 - t[o].cnt);
t[o].MaxB = max(t[o].MaxB, t[o].MaxA + add1_);
t[o].add1_ = max(t[o].add1_, t[o].add1 + add1_);
t[o].MaxA += add1; t[o].add1 += add1;
t[o].add2_ = max(t[o].add2_, t[o].add2 + add2_);
if(t[o].se != -inf) t[o].se += add2; t[o].add2 += add2;
}
void down(int o, int l, int r) {
int tmp = max(t[ls(o)].MaxA, t[rs(o)].MaxA);
if(t[ls(o)].MaxA == tmp)
modify(ls(o), l, mid, t[o].add1, t[o].add1_, t[o].add2, t[o].add2_);
else
modify(ls(o), l, mid, t[o].add2, t[o].add2_, t[o].add2, t[o].add2_);
//add1,add1_是维护区间最大值的标记,如果这个区间没有父节点的最大值,那么最大值标记不下传
if(t[rs(o)].MaxA == tmp)
modify(rs(o), mid + 1, r, t[o].add1, t[o].add1_, t[o].add2, t[o].add2_);
else
modify(rs(o), mid + 1, r, t[o].add2, t[o].add2_, t[o].add2, t[o].add2_);
t[o].add1 = t[o].add1_ = t[o].add2 = t[o].add2_ = 0;
}
void change_add(int o, int l, int r, int x, int y, int k) {
if(x <= l && y >= r) { modify(o, l, r, k, k, k, k); return ; }
down(o, l, r);
if(x <= mid) change_add(ls(o), l, mid, x, y, k);
if(y > mid) change_add(rs(o), mid + 1, r, x, y, k);
up(o);
}
void change_min(int o, int l, int r, int x, int y, int k) {
if(t[o].MaxA <= k) return ;
if(x <= l && y >= r && t[o].MaxA > k && t[o].se < k) {
modify(o, l, r, k - t[o].MaxA, k - t[o].MaxA, 0, 0);
return ;
}
down(o, l, r);
if(x <= mid) change_min(ls(o), l, mid, x, y, k);
if(y > mid) change_min(rs(o), mid + 1, r, x, y, k);
up(o);
}
long long query_sum(int o, int l, int r, int x, int y) {
if(x <= l && y >= r) return t[o].sum;
down(o, l, r);
long long res = 0;
if(x <= mid) res += query_sum(ls(o), l, mid, x, y);
if(y > mid) res += query_sum(rs(o), mid + 1, r, x, y);
return res;
}
int query_A(int o, int l, int r, int x, int y) {
if(x <= l && y >= r) { return t[o].MaxA; }
down(o, l, r);
int res = -inf;
if(x <= mid) res = max(res, query_A(ls(o), l, mid, x, y));
if(y > mid) res = max(res, query_A(rs(o), mid + 1, r, x, y));
return res;
}
int query_B(int o, int l, int r, int x, int y) {
if(x <= l && y >= r) { return t[o].MaxB; }
down(o, l, r);
int res = -inf;
if(x <= mid) res = max(res, query_B(ls(o), l, mid, x, y));
if(y > mid) res = max(res, query_B(rs(o), mid + 1, r, x, y));
return res;
}
int main() {
n = read(); m = read();
build(1, 1, n);
for(int i = 1, opt, l, r;i <= m; i++) {
opt = read(); l = read(); r = read();
if(opt == 1) x = read(), change_add(1, 1, n, l, r, x);
if(opt == 2) x = read(), change_min(1, 1, n, l, r, x);
if(opt == 3) printf("%lld\n", query_sum(1, 1, n, l, r));
if(opt == 4) printf("%d\n", query_A(1, 1, n, l, r));
if(opt == 5) printf("%d\n", query_B(1, 1, n, l, r));
}
return 0;
}