Petrozavodsk Winter-2018. AtCoder Contest. Problem I. ADD, DIV, MAX 吉司机线段树
题意:给你一个序列,需要支持以下操作:1:区间内的所有数加上某个值。2:区间内的所有数除以某个数(向下取整)。3:询问某个区间内的最大值。
思路(从未见过的套路):维护区间最大值和区间最小值,执行2操作时,继续向下寻找子区间,如果区间满足:min - (min / x) == max - (max / x)时,给这个区间内的所有数减去min - (min / x)就可以了。为什么这样做呢?因为向下取整操作变化速度远快于加法,在经过很多次操作后其实序列中的数区域相等,复杂度需要用势能分析之类的,均摊复杂度应该是O(n * (log(n) ^ 2))。
代码:
#include <bits/stdc++.h> #define LL long long #define ls (o << 1) #define rs (o << 1 | 1) using namespace std; const int maxn = 200010; struct Seg { LL add, mx, mi; }; Seg tr[maxn * 4]; LL a[maxn]; void pushup(int o) { tr[o].mx = max(tr[ls].mx, tr[rs].mx); tr[o].mi = min(tr[ls].mi, tr[rs].mi); } void pushdown(int o) { if(tr[o].add != 0) { tr[ls].add += tr[o].add; tr[ls].mi += tr[o].add; tr[ls].mx += tr[o].add; tr[rs].add += tr[o].add; tr[rs].mi += tr[o].add; tr[rs].mx += tr[o].add; tr[o].add = 0; } } void dfs(int o, int l, int r, LL val) { if(tr[o].mi - (tr[o].mi / val) == tr[o].mx - (tr[o].mx / val)) { LL tmp = tr[o].mi - (tr[o].mi / val); tr[o].add -= tmp; tr[o].mi -= tmp; tr[o].mx -= tmp; return; } int mid = (l + r) >> 1; pushdown(o); dfs(ls, l, mid, val); dfs(rs, mid + 1, r, val); pushup(o); } void build(int o, int l, int r) { if(l == r) { tr[o].add = 0; tr[o].mx = tr[o].mi = a[l]; return; } int mid = (l + r) >> 1; build(ls, l, mid); build(rs, mid + 1, r); pushup(o); } void update(int o, int l, int r, int ql, int qr, LL val, bool flag) { if(l >= ql && r <= qr) { if(flag == 0) { tr[o].mi += val; tr[o].mx += val; tr[o].add += val; } else { dfs(o, l, r, val); } return; } pushdown(o); int mid = (l + r) >> 1; if(ql <= mid) update(ls, l, mid, ql, qr, val, flag); if(qr > mid) update(rs, mid + 1, r, ql, qr, val, flag); pushup(o); } LL query(int o, int l, int r, int ql, int qr) { if(l >= ql && r <= qr) { return tr[o].mx; } int mid = (l + r) >> 1; LL ans = 0; pushdown(o); if(ql <= mid) ans = max(ans, query(ls, l, mid, ql, qr)); if(qr > mid) ans = max(ans, query(rs, mid + 1, r, ql, qr)); return ans; } int main() { int op, l, r, x, n, m; scanf("%d%d", &n, &m); for (int i = 1; i <= n; i++) { scanf("%lld", &a[i]); } build(1, 1, n); for (int i = 1; i <= m; i++) { scanf("%d", &op); if(op == 0) { scanf("%d%d%d", &l, &r, &x); l++, r++; update(1, 1, n, l, r, x, 0); } else if(op == 1) { scanf("%d%d%d", &l, &r, &x); l++, r++; if(x != 1) update(1, 1, n, l, r, x, 1); } else { scanf("%d%d%d", &l, &r, &x); l++, r++; printf("%lld\n", query(1, 1, n, l, r)); } } }