【学习笔记】Segment Tree Beats/吉司机线段树

链接

区间最值操作

HDU-5306

支持区间取 \(\min\)(即 \(a_i\leftarrow\min(a_i,x)\)),查询区间 \(\max\),查询区间和。

很容易想到一个暴力:每次找出这个区间的最大值 \(mx\),如果 \(mx>x\),那么暴力修改这个位置的值并继续找;否则已经修改完毕,退出。最坏情况下一次操作要修改 \(O(n)\) 个位置,时间复杂度为 \(O(nm \log n)\)。

打一打补丁:对线段树上的每一个区间维护区间最大值 \(mx\)、这个区间中最大值出现的次数 \(t\)、区间严格次大值 \(se\)(不存在时记为 \(-1\)),当然还要维护区间和 \(sum\)。

现在考虑打上区间取 \(\min\) 标记

  • 如果 \(mx\le x\),那么对 \(sum\) 就没有修改。
  • 如果 \(se<x<mx\),那么 \(sum=sum-(mx-x)\times t\)
  • 如果 \(x\le se<mx\),此时无法直接更新节点信息,故向下左右子树递归。我们分别 DFS 这个节点的两个孩子,如果当前 DFS 的过程中遇到了前两种情况,就直接修改打上标记然后退出,否则就继续 DFS。可以用势能分析证明,只有区间取 \(\min\) 操作时总复杂度为 \(O((n+m)\log n)\)。
点击查看代码
#include <bits/stdc++.h>
//#define int long long
#define rep(i, l, r) for(int i = l; i <= r; ++ i)
#define per(i, r, l) for(int i = r; i >= l; -- i)
using namespace std;
// Buffered getchar replacement: refills a 1 MiB buffer from stdin on demand
// and yields EOF (narrowed to char) once the input is exhausted.
char gc()
{
    static char buf[1 << 20];
    static char *head = buf, *tail = buf;
    if(head == tail)
    {
        head = buf;
        tail = buf + fread(buf, 1, 1 << 20, stdin);
        if(head == tail) return EOF;
    }
    return *head++;
}
// Reads the next decimal integer into n; it is negative when the character
// right before the first digit is '-'.
void read(int &n)
{
    char ch = gc();
    bool neg = false;
    while(ch < '0' || ch > '9')
    {
        neg = (ch == '-');
        ch = gc();
    }
    int val = 0;
    while(ch >= '0' && ch <= '9')
    {
        val = val * 10 + (ch - '0');
        ch = gc();
    }
    n = neg ? -val : val;
}
const int N = 1e6 + 10;
#define ls rt + rt
#define rs rt + rt + 1
// Per node: mx = maximum, se = strict second maximum (-1 when the node holds a
// single distinct value), cnt = how many positions equal mx, sum = range sum.
// tag is still written by push_up, but push_down takes the bound from mx instead.
int mx[N << 2], se[N << 2], cnt[N << 2], tag[N << 2];
long long sum[N << 2];
int p[N], n, m;
// Recompute node rt from its two children.
void update(int rt)
{
    sum[rt] = sum[ls] + sum[rs];
    if(mx[ls] == mx[rs])
    {
        // both children share the maximum: the counts add up
        mx[rt] = mx[ls];
        cnt[rt] = cnt[ls] + cnt[rs];
        se[rt] = max(se[ls], se[rs]);
    }
    else if(mx[ls] < mx[rs])
    {
        // the smaller child's maximum competes for the second place
        mx[rt] = mx[rs];
        cnt[rt] = cnt[rs];
        se[rt] = max(mx[ls], se[rs]);
    }
    else
    {
        mx[rt] = mx[ls];
        cnt[rt] = cnt[ls];
        se[rt] = max(se[ls], mx[rs]);
    }
}
// Build node rt over p[l..r].
void build(int rt, int l, int r)
{
    tag[rt] = 0;
    if(l == r)
    {
        mx[rt] = sum[rt] = p[l];
        se[rt] = -1; cnt[rt] = 1;   // a leaf has no second maximum
        return ;
    }
    int mid = (l + r) >> 1;
    build(ls, l, mid); build(rs, mid + 1, r);
    update(rt);
}
// Apply a_i = min(a_i, v) to the whole node rt in O(1).
// Only valid when se[rt] <= v, so that just the cnt[rt] copies of the maximum
// change; change() checks this before calling.
void push_up(int rt, int v)
{
    if(mx[rt] <= v) return;   // nothing above v: no-op
    sum[rt] -= 1ll * (mx[rt] - v) * cnt[rt];
    mx[rt] = tag[rt] = v;
}
// Push a pending range-chmin from rt down to its children.
// The parent's own maximum is the pending bound: a child whose maximum exceeds
// mx[rt] is clamped to it, and a child already <= mx[rt] is left alone by
// push_up. Reading tag[rt] instead would break for t = 0, because a tag of 0
// is indistinguishable from "no pending tag" and would never be propagated.
void push_down(int rt)
{
    push_up(ls, mx[rt]); push_up(rs, mx[rt]);
    tag[rt] = 0;
}
// a_i = min(a_i, v) for every i in [x, y]; node rt covers [l, r].
void change(int rt, int l, int r, int x, int y, int v)
{
    if(mx[rt] <= v) return;   // nothing in this node exceeds v
    // fully covered and only the maximum is affected: O(1) tag
    if(x <= l && r <= y && se[rt] <= v) {push_up(rt, v); return;}
    // otherwise split and recurse
    int mid = (l + r) >> 1; push_down(rt);
    if(mid >= x) change(ls, l, mid, x, y, v);
    if(mid < y) change(rs, mid + 1, r, x, y, v);
    update(rt);
}
// Maximum of a[x..y] restricted to node rt, which covers [l, r].
int askmax(int rt, int l, int r, int x, int y)
{
    if(x <= l && r <= y) return mx[rt];
    push_down(rt);
    const int mid = (l + r) >> 1;
    const int fromLeft = x <= mid ? askmax(ls, l, mid, x, y) : -1;
    const int fromRight = y > mid ? askmax(rs, mid + 1, r, x, y) : -1;
    return max({-1, fromLeft, fromRight});
}
// Sum of a[x..y] restricted to node rt, which covers [l, r].
long long asksum(int rt, int l, int r, int x, int y)
{
    if(x <= l && r <= y) return sum[rt];
    push_down(rt);
    const int mid = (l + r) >> 1;
    long long total = 0;
    if(x <= mid) total += asksum(ls, l, mid, x, y);
    if(y > mid) total += asksum(rs, mid + 1, r, x, y);
    return total;
}
// One test case: n, m, the array, then m operations
// (0 x y t: chmin with t, 1 x y: print range max, otherwise: print range sum).
void solve()
{
    read(n); read(m);
    rep(i, 1, n) read(p[i]);
    build(1, 1, n);
    rep(i, 1, m)
    {
        int opt, x, y, t;
        read(opt); read(x); read(y);
        if(!opt) read(t), change(1, 1, n, x, y, t);
        else if(opt == 1) printf("%d\n", askmax(1, 1, n, x, y));
        else printf("%lld\n", asksum(1, 1, n, x, y));
    }
}
// The first number is the count of test cases.
// `int` return type is mandatory in standard C++ (implicit int is ill-formed).
int main()
{
    int T; read(T); for(; T; -- T) solve();
    return 0;
}

最假女选手 BZOJ - 4695/P10639 BZOJ4695 最佳女选手

支持区间加,区间取 \(\min/\max\),求区间和,区间 \(\min/\max\)

别乱交,洛谷的时间是 \(2s\)

这题和上道题目的区别是多了区间加,区间 \(\max\)

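我们需要额外维护区间最小值 \(mn\)、严格次小值 \(mn2\)、最小值个数 \(mncnt\),以及加法标记 \(sumtag\)、取 \(\max\) 标记 \(mxtag\)、取 \(\min\) 标记 \(mntag\)。对于 \(tag\) 下传的部分,最优先处理区间加的标记(打加法标记时两个最值标记已经被同步平移),两个区间最值的标记优先度相等。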

有点复杂,代码看看就行。

点击查看代码
#include<bits/stdc++.h>
using namespace std;

#define ls rt << 1
#define rs rt << 1 | 1
#define int long long
#define rep(i, l, r) for(int i = l; i <= r; ++ i)
#define per(i, r, l) for(int i = r; i >= l; -- i)
const int N = 5e5 + 6;
const int INF = 2e9;
int n, Q, a[N];
int mx[N << 2], mx2[N << 2], mn[N << 2], mn2[N << 2], mxcnt[N << 2], mncnt[N << 2], mxtag[N << 2], mntag[N << 2], sumtag[N << 2], sum[N << 2], ll[N << 2], rr[N << 2];
void update(int rt) {
    sum[rt] = sum[ls] + sum[rs];
    if(mx[ls] == mx[rs]) {
        mx[rt] = mx[ls], mx2[rt] = max(mx2[ls], mx2[rs]);
        mxcnt[rt] = mxcnt[ls] + mxcnt[rs];
    }
    else if(mx[ls] > mx[rs]) {
        mx[rt] = mx[ls], mx2[rt] = max(mx2[ls], mx[rs]);
        mxcnt[rt] = mxcnt[ls];
    }
    else {
        mx[rt] = mx[rs], mx2[rt] = max(mx[ls], mx2[rs]);
        mxcnt[rt] = mxcnt[rs];
    }
    if(mn[ls] == mn[rs]) {
        mn[rt] = mn[ls], mn2[rt] = min(mn2[ls], mn2[rs]);
        mncnt[rt] = mncnt[ls] + mncnt[rs];
    }
    else if(mn[ls] < mn[rs]) {
        mn[rt] = mn[ls], mn2[rt] = min(mn2[ls], mn[rs]);
        mncnt[rt] = mncnt[ls];
    }
    else {
        mn[rt] = mn[rs], mn2[rt] = min(mn[ls], mn2[rs]);
        mncnt[rt] = mncnt[rs];
    }
}
// Add v to every element of node rt (range [ll[rt], rr[rt]]) and record it in sumtag.
// Sentinel values (-INF for mx2/mxtag, INF for mn2/mntag) mean "absent" and are
// skipped so they keep that meaning.
void pushadd(int rt, int v) {
    sum[rt] += (rr[rt] - ll[rt] + 1) * v;
    mx[rt] += v, mn[rt] += v;
    if(mx2[rt] != -INF) mx2[rt] += v;
    if(mn2[rt] != INF) mn2[rt] += v;
    // pending chmax/chmin bounds move together with the values
    if(mxtag[rt] != -INF) mxtag[rt] += v;
    if(mntag[rt] != INF) mntag[rt] += v;
    sumtag[rt] += v;
}
// Apply a_i = max(a_i, v) to node rt in O(1). change() calls it only when
// mn2[rt] > v, i.e. only the mncnt[rt] copies of the minimum are raised.
// Unlike pushmin, v == mn[rt] is not filtered out; sum and mn are then unchanged.
void pushmax(int rt, int v) {
    if(mn[rt] > v) return;
    sum[rt] += (v - mn[rt]) * mncnt[rt];
    // With exactly two distinct values the minimum is also the strict second
    // maximum; with one distinct value it is also the maximum.
    if(mx2[rt] == mn[rt]) mx2[rt] = v;
    if(mx[rt] == mn[rt]) mx[rt] = v;
    // keep a pending chmin bound from falling below the new floor v
    if(mntag[rt] < v) mntag[rt] = v;
    // last, because the comparisons above need the old mn[rt]
    mn[rt] = v, mxtag[rt] = v;
}
// Apply a_i = min(a_i, v) to node rt in O(1); mirror image of pushmax.
// change() calls it only when mx2[rt] < v.
void pushmin(int rt, int v) {
    if(mx[rt] <= v) return;
    sum[rt] += (v - mx[rt]) * mxcnt[rt];
    if(mn2[rt] == mx[rt]) mn2[rt] = v;
    if(mn[rt] == mx[rt]) mn[rt] = v;
    // keep a pending chmax bound from exceeding the new ceiling v
    if(mxtag[rt] > v) mxtag[rt] = v;
    // last, because the comparisons above need the old mx[rt]
    mx[rt] = v, mntag[rt] = v;
}
// Push node rt's pending tags to both children in the order add -> chmax -> chmin.
// The add must go first: pushadd also shifts the pending mxtag/mntag bounds by
// the added amount, so the bounds stored at rt are already in post-add terms.
// Each tag is reset to its "absent" value (0 / -INF / INF) after it is pushed.
void pushdown(int rt) {    
    if(sumtag[rt]) {
        pushadd(ls, sumtag[rt]);
        pushadd(rs, sumtag[rt]);
        sumtag[rt] = 0;
    }
    if(mxtag[rt] != -INF) {
        pushmax(ls, mxtag[rt]);
        pushmax(rs, mxtag[rt]);
        mxtag[rt] = -INF;
    }
    if(mntag[rt] != INF) {
        pushmin(ls, mntag[rt]);
        pushmin(rs, mntag[rt]);
        mntag[rt] = INF;
    }
}
// Build node rt over a[l..r]: record its range, clear both bound tags, fill aggregates.
void build(int rt, int l, int r) {
    ll[rt] = l;
    rr[rt] = r;
    mxtag[rt] = -INF;
    mntag[rt] = INF;
    if(l != r) {
        const int mid = (l + r) >> 1;
        build(ls, l, mid);
        build(rs, mid + 1, r);
        update(rt);
        return;
    }
    // leaf: a single value, so there is no strict second max/min
    sum[rt] = mx[rt] = mn[rt] = a[l];
    mx2[rt] = -INF;
    mn2[rt] = INF;
    mxcnt[rt] = mncnt[rt] = 1;
}
// op 1: a_i += v, op 2: a_i = max(a_i, v), op 3: a_i = min(a_i, v), for i in [x, y].
void change(int rt, int x, int y, int v, int op) {
    const int l = ll[rt], r = rr[rt];
    // prune a chmax/chmin that cannot change anything in this node
    if((op == 2 && mn[rt] >= v) || (op == 3 && mx[rt] <= v)) return;
    if(x <= l && r <= y) {
        // fully covered: apply in O(1) when only the extreme value is affected
        bool done = true;
        if(op == 1) pushadd(rt, v);
        else if(op == 2 && mn2[rt] > v) pushmax(rt, v);
        else if(op == 3 && mx2[rt] < v) pushmin(rt, v);
        else done = false;
        if(done) return;
    }
    pushdown(rt);
    const int mid = (l + r) >> 1;
    if(x <= mid) change(ls, x, y, v, op);
    if(y > mid) change(rs, x, y, v, op);
    update(rt);
}
int ask(int rt, int x, int y, int op) {
    int l = ll[rt], r = rr[rt];
    if(x <= l && r <= y) {
        int xx;
        if(op == 6) xx = mn[rt];
        else if(op == 5) xx = mx[rt];
        else xx = sum[rt];
        return xx;
    }
    int mid = (l + r) >> 1, res;
    pushdown(rt);
    if(op == 6) {
        res = INF;
        if(x <= mid) res = min(res, ask(ls, x, y, op));
        if(y > mid) res = min(res, ask(rs, x, y, op));
    }
    else if(op == 5) {
        res = -INF;
        if(x <= mid) res = max(res, ask(ls, x, y, op));
        if(y > mid) res = max(res, ask(rs, x, y, op));
    }
    else {
        res = 0;
        if(x <= mid) res += ask(ls, x, y, op);
        if(y > mid) res += ask(rs, x, y, op);
    }
    return res;
}
// Reads the array and Q operations: 1-3 are updates (add / chmax / chmin),
// 4-6 are queries (sum / max / min).
// `int` is macro-redefined to long long, so the entry point is declared `signed main`.
signed main() {
    scanf("%lld", &n); rep(i, 1, n) scanf("%lld", &a[i]); build(1, 1, n);
    scanf("%lld", &Q);
    for(; Q; -- Q) {
        int op, l, r, v; scanf("%lld%lld%lld", &op, &l, &r);
        if(op <= 3) scanf("%lld", &v), change(1, l, r, v, op);
        else printf("%lld\n", ask(1, l, r, op));
    }
    return 0;
}

历史最值操作

Tyvj 1518 CPU监控/bzoj3060/P4314 CPU监控

支持区间加,区间覆盖,查询区间 \(\max\) 与区间历史 \(\max\)

一开始的想法是用当前最大值及当前的标记维护历史最大值,但可以举出反例:

8
1 1 10 1 1 1 1 1
3
P 1 8 100
P 1 8 -10
A 3 3

正确答案是 \(110\),但输出 \(100\):第二个操作会把第一个操作留下的 add 标记抵消成净增量 \(90\),下传时儿子只看到 \(+90\),中间曾经达到的 \(+100\) 丢失了,儿子的历史 \(\max\) 得不到及时的更新。
所以记录以下信息:

mx: 区间最大值
hmx:区间历史最大值
vis:是否进行过赋值操作
sum:当前节点上次 push_down 之后的加和
asg: 当前节点上次 push_down 之后的赋值操作 (赋值之后的区间加操作算入赋值)
summx:上次 push_down 之后达到的最大加和
asgmx:上次 push_down 之后赋的最大值

点击查看代码
#include <bits/stdc++.h>
//#define int long long
#define rep(i, l, r) for(int i = l; i <= r; ++ i)
#define per(i, r, l) for(int i = r; i >= l; -- i)
using namespace std;

#define gc getchar
// Reads the next decimal integer into n; it is negative when the character
// right before the first digit is '-'.
void read(int &n)
{
    char ch = gc();
    bool neg = false;
    while(ch < '0' || ch > '9')
    {
        neg = (ch == '-');
        ch = gc();
    }
    int val = 0;
    while(ch >= '0' && ch <= '9')
    {
        val = val * 10 + (ch - '0');
        ch = gc();
    }
    n = neg ? -val : val;
}
const int N = 1e5 + 10, INF = 0x3f3f3f3f;
int n, m;
int p[N];
#define ls rt + rt
#define rs rt + rt + 1
// mx: current maximum, hmx: historical maximum of the node.
// sum / summx: pending add total and its peak running total since the last push_down.
// vis / asg / asgmx: whether an assignment is pending, its current value, and the
// peak value reached under it (later adds included).
int mx[N << 2], hmx[N << 2], sum[N << 2], asg[N << 2], summx[N << 2], asgmx[N << 2];
bool vis[N << 2];
// Both the current and the historical maximum of a node are the larger of its children's.
void update(int rt)
{
    const int L = ls, R = rs;
    mx[rt] = mx[L] < mx[R] ? mx[R] : mx[L];
    hmx[rt] = hmx[L] < hmx[R] ? hmx[R] : hmx[L];
}
// Build node rt over p[l..r]; a leaf's history starts at its initial value.
void build(int rt, int l, int r)
{
    if(l != r)
    {
        const int mid = (l + r) >> 1;
        build(ls, l, mid);
        build(rs, mid + 1, r);
        update(rt);
        return;
    }
    mx[rt] = hmx[rt] = p[l];
}
// Apply to node rt a batch of range adds whose net total is `val` and whose
// largest running (prefix) total was `mxval`.
// If rt already has a pending assignment (vis), the adds are folded into it
// (asg / asgmx); otherwise they accumulate in the plain add tag (sum / summx).
void push_up_add(int rt, int val, int mxval)
{
    if(vis[rt]) asgmx[rt] = max(asgmx[rt], asg[rt] + mxval), asg[rt] += val;
    else summx[rt] = max(summx[rt], sum[rt] + mxval), sum[rt] += val;
    // the node's peak during this batch was its current max plus the peak prefix
    hmx[rt] = max(hmx[rt], mx[rt] + mxval);
    mx[rt] += val;
}
// Apply to node rt an assignment to `val`, where `mxval` is the largest value
// that was assigned (or reached via adds on top of an assignment) in the batch.
void push_up_change(int rt, int val, int mxval)
{
    hmx[rt] = max(hmx[rt], mxval);
    if(vis[rt]) asgmx[rt] = max(asgmx[rt], mxval);
    else vis[rt] = 1, asgmx[rt] = mxval;   // first assignment since the last push_down
    mx[rt] = asg[rt] = val;
}
// Push rt's tags to its children: the add tag first (it only holds adds made
// before any assignment at rt, since later adds are folded into asg), then the
// assignment, if one is pending. All tags of rt are cleared afterwards.
void push_down(int rt)
{
    push_up_add(ls, sum[rt], summx[rt]);
    push_up_add(rs, sum[rt], summx[rt]);
    sum[rt] = summx[rt] = 0;
    if(!vis[rt]) return;
    push_up_change(ls, asg[rt], asgmx[rt]);
    push_up_change(rs, asg[rt], asgmx[rt]);
    vis[rt] = asg[rt] = asgmx[rt] = 0;
}
// a_i += val for i in [x, y]; node rt covers [l, r].
void add(int rt, int l, int r, int x, int y, int val)
{
    if(l >= x && r <= y)
    {
        // a single add is also its own peak running total
        push_up_add(rt, val, val);
        return;
    }
    push_down(rt);
    const int mid = (l + r) >> 1;
    if(x <= mid) add(ls, l, mid, x, y, val);
    if(y > mid) add(rs, mid + 1, r, x, y, val);
    update(rt);
}
// a_i = val for i in [x, y]; node rt covers [l, r].
void change(int rt, int l, int r, int x, int y, int val)
{
    if(l >= x && r <= y)
    {
        push_up_change(rt, val, val);
        return;
    }
    push_down(rt);
    const int mid = (l + r) >> 1;
    if(x <= mid) change(ls, l, mid, x, y, val);
    if(y > mid) change(rs, mid + 1, r, x, y, val);
    update(rt);
}
// Shared descent: maximum of arr[] over the canonical nodes covering [x, y].
// arr points at either hmx or mx; push_down keeps both up to date on the way down.
int queryMax(const int *arr, int rt, int l, int r, int x, int y)
{
    if(x <= l && r <= y) return arr[rt];
    int best = -INF;
    push_down(rt);
    const int mid = (l + r) >> 1;
    if(x <= mid) best = max(best, queryMax(arr, ls, l, mid, x, y));
    if(y > mid) best = max(best, queryMax(arr, rs, mid + 1, r, x, y));
    return best;
}
// Historical maximum over [x, y].
int History(int rt, int l, int r, int x, int y) { return queryMax(hmx, rt, l, r, x, y); }
// Current maximum over [x, y].
int Max(int rt, int l, int r, int x, int y) { return queryMax(mx, rt, l, r, x, y); }
// Reads the array, then m commands: Q x y (current max), A x y (historical max),
// P x y z (range add), C x y z (range assign).
// Input mixes getchar (read) with cin, so stdio synchronisation must stay enabled.
int main()
{
    read(n);
    rep(i, 1, n) read(p[i]);
    build(1, 1, n); read(m);
    rep(i, 1, m)
    {
        char ch; cin >> ch;
        int x, y, t;
        read(x); read(y);
        if(ch == 'Q') printf("%d\n", Max(1, 1, n, x, y));
        if(ch == 'A') printf("%d\n", History(1, 1, n, x, y));
        if(ch == 'P') read(t), add(1, 1, n, x, y, t);
        if(ch == 'C') read(t), change(1, 1, n, x, y, t);
    }
    return 0;
}

杂题

Abl_e

\(n\) 个数,初始均为 \(1\)\(q\) 次修改,每次修改把所有 \([l,r]\) 内的数都修改成 \(x(1\le x\le 9)\),问每次修改后输出所有的数从左至右组成的十进制的数与 \(998244353\) 取余。\(1\le n,q\le 2\times 10^5\)

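显而易见,修改操作就是区间赋值:把 \([l,r]\) 全部改成 \(x\) 后,这一段对整个数的贡献为 \(x \times \sum^{r}_{i=l}10^{n-i}\)。线段树每个节点维护它的区间组成的数对 \(998244353\) 取模的值,合并时左儿子乘上 \(10^{\text{右儿子长度}}\) 再加右儿子,区间赋值用懒标记维护即可。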

注意 lazy 标记的时候取模操作。

点击查看代码
#include <bits/stdc++.h>
#define int long long
#define rep(i, l, r) for(int i = l; i <= r; ++ i)
#define per(i, r, l) for(int i = r; i >= l; -- i)
using namespace std;

#define gc getchar
// Reads the next decimal integer into n; it is negative when the character
// right before the first digit is '-'.
void read(int &n)
{
    char ch = gc();
    bool neg = false;
    while(ch < '0' || ch > '9')
    {
        neg = (ch == '-');
        ch = gc();
    }
    int val = 0;
    while(ch >= '0' && ch <= '9')
    {
        val = val * 10 + (ch - '0');
        ch = gc();
    }
    n = neg ? -val : val;
}
const int N = 2e5 + 10, INF = 0x3f3f3f3f, mod = 998244353;
int n, m;
// p[k] = 10^k mod `mod`; num[k] = the number 11...1 (k ones) mod `mod`.
int p[N << 2], num[N << 2];
#define ls rt + rt
#define rs rt + rt + 1
// sum: the digits of the node's range read as a decimal number, mod `mod`;
// tag: pending "set every digit to tag" (-1 = none).
int sum[N << 2], tag[N << 2];
// Concatenate the children's numbers: the left part is shifted left by the
// number of digits in the right part.
void update(int rt, int l, int r)
{
    const int rightLen = r - ((l + r) >> 1);
    const int shifted = sum[ls] % mod * p[rightLen] % mod;
    sum[rt] = (shifted + sum[rs]) % mod;
}
// Every digit starts as 1; no pending assignment anywhere.
void build(int rt, int l, int r)
{
    tag[rt] = -1;
    if(l != r)
    {
        const int mid = (l + r) >> 1;
        build(ls, l, mid);
        build(rs, mid + 1, r);
        update(rt, l, r);
        return;
    }
    sum[rt] = 1;
}
// Set every digit of node rt (range [l, r]) to v: the value becomes v * 11...1.
void push_up(int rt, int l, int r, int v)
{
    tag[rt] = v;
    sum[rt] = num[r - l + 1] % mod * v % mod;
}
// Hand a pending assignment of node rt down to both children.
void push_down(int rt, int l, int r)
{
    if(tag[rt] != -1)
    {
        const int mid = (l + r) >> 1;
        push_up(ls, l, mid, tag[rt]);
        push_up(rs, mid + 1, r, tag[rt]);
        tag[rt] = -1;
    }
}
// Set every digit in [x, y] to v; node rt covers [l, r].
void change(int rt, int l, int r, int x, int y, int v)
{
    if(l >= x && r <= y)
    {
        push_up(rt, l, r, v);
        return;
    }
    push_down(rt, l, r);
    const int mid = (l + r) >> 1;
    if(x <= mid) change(ls, l, mid, x, y, v);
    if(y > mid) change(rs, mid + 1, r, x, y, v);
    update(rt, l, r);
}
// Precomputes powers of ten and repunits, then prints the whole number mod 998244353
// after each range assignment.
// `int` is macro-redefined to long long here, hence `signed main`.
signed main()
{
    read(n); read(m);
    p[0] = 1;
    rep(i, 1, n) p[i] = p[i - 1] * 10 % mod;              // p[i] = 10^i
    rep(i, 1, n) num[i] = (num[i - 1] * 10 + 1) % mod;    // num[i] = 11...1 (i ones)
    build(1, 1, n);
    rep(i, 1, m)
    {
        int l, r, x;
        read(l); read(r); read(x);
        change(1, 1, n, l, r, x);
        // sum[] is long long under `#define int long long`; "%d" would be undefined behaviour
        printf("%lld\n", static_cast<long long>(sum[1]));
    }
    return 0;
}

Greedy Shopping

给定一个长度为 \(n\) 的非升序列 \(a\)。第一个操作修改 \(\forall i\in[1,x],a_i=\max(a_i,y)\);第二个操作从下标 \(x\) 开始,从左往右访问序列 \(a\),如果 \(a_i\le y\),则 \(\text{answer}++,y=y-a_i\),最后输出 \(\text{answer}\)。

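由于序列非升,操作一等价于:用线段树二分找出第一个满足 \(a_i\le y\) 的位置 \(v\),若 \(v\le x\) 则把 \([v,x]\) 区间赋值为 \(y\)。操作二在线段树上贪心下降:整段和不超过剩余的 \(y\) 就整段买下,整段最小值大于 \(y\) 就整段跳过,否则递归左右儿子(必须先左后右)。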

点击查看代码
#include <bits/stdc++.h>
#define int long long
#define rep(i, l, r) for(int i = l; i <= r; ++ i)
#define per(i, r, l) for(int i = r; i >= l; -- i)
using namespace std;
const int N = 8e5 + 10;
// Node arrays: mn = range minimum, sum = range sum, [ll, rr] = node range,
// tag = pending range assignment (-1 = none). a = input sequence (1-indexed).
int n, m, mn[N], sum[N], ll[N], rr[N], a[N], tag[N];
#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define len(x) (rr[x] - ll[x] + 1)
// Recompute node rt from its children.
void update(int rt)
{
    mn[rt] = min(mn[ls], mn[rs]);
    sum[rt] = sum[ls] + sum[rs];
}
// Build node rt over a[l..r].
void build(int rt, int l, int r)
{
    tag[rt] = -1;
    ll[rt] = l; rr[rt] = r;
    if(l == r)
    {
        mn[rt] = sum[rt] = a[l];
        return ;
    }
    int mid = (l + r) >> 1;
    build(ls, l, mid); build(rs, mid + 1, r);
    update(rt);
}
// Assign k to every element of node rt.
void push_up(int rt, int k)
{
    mn[rt] = tag[rt] = k;
    sum[rt] = len(rt) * k;
}
// Hand a pending assignment of node rt down to both children.
void push_down(int rt)
{
    if(tag[rt] == -1) return ;
    push_up(ls, tag[rt]);
    push_up(rs, tag[rt]);
    tag[rt] = -1;
}
// a_i = v for i in [x, y].
void change(int rt, int x, int y, int v)
{
    int l = ll[rt], r = rr[rt];
    if(x <= l && r <= y) {push_up(rt, v); return;}
    int mid = (l + r) >> 1; push_down(rt);
    if(mid >= x) change(ls, x, y, v);
    if(mid < y) change(rs, x, y, v);
    update(rt);
}
// Leftmost position whose value is <= k; one past the last leaf (n + 1 for a
// tree over [1, n]) when no such position exists.
int ask_id(int rt, int k)
{
    int l = ll[rt], r = rr[rt];
    if(l == r) return l + !(sum[rt] <= k);
    push_down(rt);
    if(mn[ls] <= k) return ask_id(ls, k);
    return ask_id(rs, k);
}
// Greedy walk from position x with budget k: buy a_i whenever a_i <= k
// (k -= a_i). Returns the number of items bought; k is updated in place.
int ask(int rt, int x, int &k)
{
    int l = ll[rt], r = rr[rt];
    if(x > r || mn[rt] > k) return 0;      // out of range, or nothing affordable
    if(l >= x && sum[rt] <= k)             // the whole node is affordable
    {
        k -= sum[rt];
        return len(rt);
    }
    push_down(rt);
    // The greedy must finish the left subtree before starting the right one,
    // because both calls share k by reference. The evaluation order of the
    // operands of '+' is unspecified in C++, so sequence the calls explicitly.
    const int fromLeft = ask(ls, x, k);
    return fromLeft + ask(rs, x, k);
}
// Reads n, m and the sequence, then answers m operations:
// 1 x y: a_i = max(a_i, y) on [1, x]; 2 x y: print the greedy purchase count from x with budget y.
// `int` is macro-redefined to long long here, hence `signed main`.
signed main()
{
    // only iostream is used for I/O, so detaching it from stdio is safe
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    rep(i, 1, n) cin >> a[i];
    memset(mn, 0x3f, sizeof mn); build(1, 1, n);
    rep(i, 1, m)
    {
        int opt, x, y;
        cin >> opt >> x >> y;
        if(opt == 1)
        {
            // the sequence is non-increasing, so the affected positions form [v, x]
            int v = ask_id(1, y);
            if(v <= x) change(1, v, x, y);
        }
        else cout << ask(1, x, y) << '\n';
    }
    return 0;
}

序列操作 洛谷 - P2572

给你一个 \(01\) 序列,支持区间全变 \(0\)、区间全变 \(1\)、区间取反(异或 \(1\))操作,查询区间内 \(1\) 的个数,或区间内最长连续 \(1\) 的长度。

sum:这段区间内 \(1\) 的总数
mx0:这段区间内连续 \(0\) 的最大长度
mx1:这段区间内连续 \(1\) 的最大长度
lmx0:这段区间从左端点连续 \(0\) 的最大长度
lmx1:这段区间从左端点连续 \(1\) 的最大长度
rmx0:这段区间从右端点连续 \(0\) 的最大长度
rmx1:这段区间从右端点连续 \(1\) 的最大长度
op:翻转标记
tag:下推标记(记录是否全赋为1或0)

update 需要维护这些东西。

void update(int rt)
{
    mx0[rt] = max({mx0[ls], mx0[rs], rmx0[ls] + lmx0[rs]});
    mx1[rt] = max({mx1[ls], mx1[rs], rmx1[ls] + lmx1[rs]});
    lmx0[rt] = lmx0[ls]; if(lmx0[ls] == len(ls)) lmx0[rt] += lmx0[rs];
    lmx1[rt] = lmx1[ls]; if(lmx1[ls] == len(ls)) lmx1[rt] += lmx1[rs];
    rmx0[rt] = rmx0[rs]; if(rmx0[rs] == len(rs)) rmx0[rt] += rmx0[ls];
    rmx1[rt] = rmx1[rs]; if(rmx1[rs] == len(rs)) rmx1[rt] += rmx1[ls];
    sum[rt] = sum[ls] + sum[rs];
}

一定要先覆盖再翻转。个人建议写两个 push_up 函数。

点击查看代码
#include <bits/stdc++.h>
//#define int long long
#define rep(i, l, r) for(int i = l; i <= r; ++ i)
#define per(i, r, l) for(int i = r; i >= l; -- i)
using namespace std;
const int N = 8e5 + 10;
int n, m;
int a[N];
#define ls rt << 1
#define rs rt << 1 | 1
#define len(x) (rr[x] - ll[x] + 1) // the outer parentheses are required: len() is used inside larger expressions
// mx0/mx1: longest run of 0s/1s; lmx*/rmx*: run touching the left/right edge;
// op: pending flip; tag: pending cover value (-1 = none); [ll, rr]: node range; sum: number of 1s.
int mx0[N], mx1[N], lmx0[N], lmx1[N], rmx0[N], rmx1[N], op[N], tag[N], ll[N], rr[N], sum[N];
// Rebuild node rt's run statistics from its two children.
void update(int rt)
{
    const int L = ls, R = rs;
    // longest runs: inside one child, or the suffix of L glued to the prefix of R
    mx0[rt] = max({mx0[L], mx0[R], rmx0[L] + lmx0[R]});
    mx1[rt] = max({mx1[L], mx1[R], rmx1[L] + lmx1[R]});
    // a prefix (suffix) run crosses into the other child only if it fills its own child
    lmx0[rt] = lmx0[L] + (lmx0[L] == len(L) ? lmx0[R] : 0);
    lmx1[rt] = lmx1[L] + (lmx1[L] == len(L) ? lmx1[R] : 0);
    rmx0[rt] = rmx0[R] + (rmx0[R] == len(R) ? rmx0[L] : 0);
    rmx1[rt] = rmx1[R] + (rmx1[R] == len(R) ? rmx1[L] : 0);
    sum[rt] = sum[L] + sum[R];
}
void build(int rt, int l, int r)
{
    ll[rt] = l; rr[rt] = r;
    op[rt] = 0; tag[rt] = -1;
    if(l == r)
    {
        mx0[rt] = lmx0[rt] = rmx0[rt] = (a[l] ^ 1);
        sum[rt] = mx1[rt] = lmx1[rt] = rmx1[rt] = a[l];
        return ;
    }
    int mid = (l + r) >> 1;
    build(ls, l, mid); build(rs, mid + 1, r);
    update(rt);
}
void push_up_rev(int rt)
{
    sum[rt] = len(rt) - sum[rt];
    swap(mx1[rt], mx0[rt]);
    swap(lmx1[rt], lmx0[rt]);
    swap(rmx1[rt], rmx0[rt]);
    op[rt] ^= 1;
}
void push_up_cov(int rt, int v)
{
    op[rt] = 0;
    mx0[rt] = lmx0[rt] = rmx0[rt] = (v ^ 1) * len(rt);
    sum[rt] = mx1[rt] = lmx1[rt] = rmx1[rt] = v * len(rt);
    tag[rt] = v;
}
void push_down_cov(int rt)
{
    if(tag[rt] == -1) return;
    push_up_cov(ls, tag[rt]);
    push_up_cov(rs, tag[rt]);
    tag[rt] = -1;
}
void push_down_rev(int rt)
{
    if(!op[rt]) return;
    push_up_rev(ls);
    push_up_rev(rs);
    op[rt] = 0;
}
// opt 0/1: set every bit of [x, y] to opt; opt 2: flip [x, y].
void change(int rt, int x, int y, int opt)
{
    const int l = ll[rt], r = rr[rt];
    if(x <= l && r <= y)
    {
        if(opt > 1) push_up_rev(rt);
        else push_up_cov(rt, opt);
        return;
    }
    push_down_cov(rt);
    push_down_rev(rt);
    const int mid = (l + r) >> 1;
    if(x <= mid) change(ls, x, y, opt);
    if(y > mid) change(rs, x, y, opt);
    update(rt);
}
// Number of ones in [x, y].
int ask1(int rt, int x, int y)
{
    const int l = ll[rt], r = rr[rt];
    if(x <= l && r <= y) return sum[rt];
    push_down_cov(rt);
    push_down_rev(rt);
    const int mid = (l + r) >> 1;
    int ones = x <= mid ? ask1(ls, x, y) : 0;
    if(y > mid) ones += ask1(rs, x, y);
    return ones;
}
// Accumulators for ask2: ans = longest run of ones so far (-1 before the first
// covering node), rmx = length of the run of ones ending at the right edge so far.
int ans = -1, rmx;
// Longest run of ones in [x, y]: visits the covering nodes left to right and
// glues the running right-edge run onto each node's prefix run.
void ask2(int rt, int x, int y)
{
    const int l = ll[rt], r = rr[rt];
    if(!(x <= l && r <= y))
    {
        push_down_cov(rt);
        push_down_rev(rt);
        const int mid = (l + r) >> 1;
        if(x <= mid) ask2(ls, x, y);
        if(y > mid) ask2(rs, x, y);
        return;
    }
    if(ans == -1)
    {
        ans = mx1[rt];
        rmx = rmx1[rt];
        return;
    }
    ans = max({ans, mx1[rt], rmx + lmx1[rt]});
    rmx = rmx1[rt] == len(rt) ? rmx + len(rt) : rmx1[rt];
}
// Positions in the input are 0-indexed and are shifted onto the 1-indexed tree.
// opt 0/1: set to 0/1, opt 2: flip, opt 3: count ones, opt 4: longest run of ones.
int main()
{
    // only iostream is used for I/O, so detaching it from stdio is safe
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    rep(i, 1, n) cin >> a[i];
    build(1, 1, n);
    rep(i, 1, m)
    {
        int opt, l, r;
        cin >> opt >> l >> r; ++ l; ++ r;
        if(opt <= 2) change(1, l, r, opt);
        else if(opt == 3) cout << ask1(1, l, r) << '\n';
        else
        {
            ans = -1; ask2(1, l, r);
            cout << ans << '\n';
        }
    }
    return 0;
}

AT practice2_j

单点修改,区间求 \(\max\),找到并输出最小的 \(j\) ,满足 \(X_i \leq j \leq N, V_i \leq A_j\) ,否则输出 \(n + 1\)

和这题差不多 Abl_e,还简单一点。

自己看代码。

点击查看代码
#include <bits/stdc++.h>
#define int long long
#define rep(i, l, r) for(int i = l; i <= r; ++ i)
#define per(i, r, l) for(int i = r; i >= l; -- i)
#define ls rt << 1 
#define rs rt << 1 | 1
using namespace std;
const int N = 8e5 + 10;
int ll[N], rr[N], mx[N], a[N];
int n, m;
int ans;
void update(int rt) {mx[rt] = max(mx[ls], mx[rs]);}
void build(int rt, int l, int r)
{
    ll[rt] = l; rr[rt] = r;
    if(l == r) {mx[rt] = a[l]; return;}
    int mid = (l + r) >> 1; build(ls, l, mid); build(rs, mid + 1, r); update(rt);
}
void change(int rt, int x, int v)
{
    int l = ll[rt], r = rr[rt];
    if(l == r) {mx[rt] = v; return;}
    int mid = (l + r) >> 1;
    if(mid >= x) change(ls, x, v);
    else change(rs, x, v);
    update(rt);
}
int ask(int rt, int x, int y)
{
    int l = ll[rt], r = rr[rt];
    if(x <= l && r <= y) return mx[rt];
    int mid = (l + r) >> 1, ret = 0;
    if(mid >= x) ret = max(ret, ask(ls, x, y));
    if(mid < y) ret = max(ret, ask(rs, x, y));
    return ret;
}
void ask1(int rt, int x, int y, int v)
{
    int l = ll[rt], r = rr[rt];
    if(ans != n + 1 || mx[rt] < v) return;
    if(l == r) {ans = l; return;}
    int mid = (l + r) >> 1;
    if(mid >= x) ask1(ls, x, y, v);
    if(mid < y) ask1(rs, x, y, v);
}
// Operations: 1 x v: a_x = v; 2 l r: print max of a[l..r];
// 3 x v: print the smallest j >= x with a_j >= v (n + 1 if none).
// `int` is macro-redefined to long long here, hence `signed main`.
signed main()
{
    cin >> n >> m;
    rep(i, 1, n) cin >> a[i];
    build(1, 1, n);
    rep(i, 1, m)
    {
        int x, y, opt;
        cin >> opt >> x >> y;
        if(opt == 1) change(1, x, y);
        else if(opt == 2) printf("%lld\n", ask(1, x, y));
        else {
            ans = n + 1; ask1(1, x, n, y);
            printf("%lld\n", ans);
        }
    }
    return 0;
}
posted @ 2024-11-16 14:33  liukejie  阅读(5)  评论(0编辑  收藏  举报