P4344 SHOI2015 脑洞治疗仪

P4344 SHOI2015 脑洞治疗仪 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

其实是水紫。原题题面描述比较诡异,这里精炼一下。

维护一个 01 序列,支持以下操作:

  • 给定 \(l\)\(r\),将区间 \([l, r]\) 推平为 \(0\)
  • 给定 \(l\)\(r\)\(x\),将区间 \([l, r]\) 的前 \(x\)\(0\) 修改为 \(1\),如果 \(x\) 超过该区间中 \(0\) 的数量则直接将该区间推平为 \(1\)
  • 给定 \(l\)\(r\),查询 \([l, r]\)\(1\) 的数量;
  • 给定 \(l\)\(r\),查询 \([l, r]\) 中的纯 \(0\) 最大连续子序列长度。

一、三、四操作是线段树老生常谈的问题了,我们考虑维护以下信息:

  • sum 表示 \(1\) 的数量;
  • lmx 表示从左开始纯 \(0\) 最大连续子序列长度;
  • rmx 表示从右开始纯 \(0\) 最大连续子序列长度;
  • ans 表示纯 \(0\) 最大连续子序列长度;
  • len 表示区间长度。

那么对于 pushup 有:

t[p].sum = t[ls(p)].sum + t[rs(p)].sum;
t[p].lmx = max(t[ls(p)].lmx, (t[ls(p)].len + t[rs(p)].lmx) * (t[ls(p)].sum == 0));
t[p].rmx = max(t[rs(p)].rmx, (t[rs(p)].len + t[ls(p)].rmx) * (t[rs(p)].sum == 0));
t[p].ans = max({t[ls(p)].ans, t[rs(p)].ans, t[ls(p)].rmx + t[rs(p)].lmx});

这个转移可以自己尝试推导一下,如果做过线段树维护区间最大连续子序列和,应该不难理解。

推平的懒标记很好写。记录 tag。为 \(0\) 表示无标记。为 \(1\) 表示推平成 \(0\),为 \(2\) 表示推平成 \(1\)

推平状态下维护的那四个变量(区间长度不算)也是显然的。然后就没了。

那么怎么做操作二?

其实很简单,二分一下中间点 \(k\),询问 \([l, k]\)\(0\) 的数量,就可以找到使得 \([l, k]\)\(0\) 的数量为 \(x\) 的那个 \(k\)。然后推平 \([l, k]\) 就可以了。

这样以来操作二复杂度是 \(\log^2 n\) 的,可以通过本题。但事实上可以单 \(\log\)

考虑在线段树上搞这个操作二,有两种情况:

  • 所有 \(1\) 都填到了左半区间;
  • \(1\) 把左半区间填满(推平)了,然后剩下的 \(1\) 右填到了右半区间。

复杂度 \(\Theta(n \log n + m \log n)\)

/*
 * @Author: crab-in-the-northeast 
 * @Date: 2023-01-04 05:00:10 
 * @Last Modified by: crab-in-the-northeast
 * @Last Modified time: 2023-01-04 05:42:04
 */
#include <bits/stdc++.h>
inline int read() {
    int x = 0;
    bool f = true;
    char ch = getchar();
    for (; !isdigit(ch); ch = getchar())
        if (ch == '-')
            f = false;
    for (; isdigit(ch); ch = getchar())
        x = (x << 1) + (x << 3) + ch - '0';
    return f ? x : (~(x - 1));
}
inline int ls(int p) {
    return p << 1;
}
inline int rs(int p) {
    return p << 1 | 1;
}

const int maxn = (int)2e5 + 5;

struct node {
    int len, sum, lmx, rmx, ans;
} t[maxn << 2];

int laz[maxn << 2];

node con(node lef, node rgt) {
    return (node) {
        lef.len + rgt.len,
        lef.sum + rgt.sum,
        std :: max(lef.lmx, (lef.len + rgt.lmx) * (lef.sum == 0)),
        std :: max(rgt.rmx, (rgt.len + lef.rmx) * (rgt.sum == 0)),
        std :: max({lef.ans, rgt.ans, lef.rmx + rgt.lmx})
    };
}

void up(int p) {
    t[p] = con(t[ls(p)], t[rs(p)]);
}

void build(int p, int l, int r) {
    if (l == r) {
        t[p] = (node) {
            1, 1, 0, 0, 0
        };
        return ;
    }
    int mid = (l + r) >> 1;
    build(ls(p), l, mid);
    build(rs(p), mid + 1, r);
    up(p);
}

void tag(int p, int op) {
    if (op == 1) {
        t[p] = (node) {
            t[p].len, 0, t[p].len, t[p].len, t[p].len
        };
        laz[p] = 1;
    } else {
        t[p] = (node) {
            t[p].len, t[p].len, 0, 0, 0
        };
        laz[p] = 2;
    }
}

void down(int p) {
    if (!laz[p])
        return ;
    tag(ls(p), laz[p]);
    tag(rs(p), laz[p]);
    laz[p] = 0;
}

node query(int p, int l, int r, int L, int R) {
    if (l == L && r == R)
        return t[p];
    down(p);
    int mid = (l + r) >> 1;
    if (R <= mid)
        return query(ls(p), l, mid, L, R);
    else if (L > mid)
        return query(rs(p), mid + 1, r, L, R);
    else
        return con(
            query(ls(p), l, mid, L, mid),
            query(rs(p), mid + 1, r, mid + 1, R)
            );
}

void clear(int p, int l, int r, int L, int R) {
    if (l == L && R == r) {
        tag(p, 1);
        return ;
    }
    down(p);
    int mid = (l + r) >> 1;
    if (R <= mid)
        clear(ls(p), l, mid, L, R);
    else if (L > mid)
        clear(rs(p), mid + 1, r, L, R);
    else {
        clear(ls(p), l, mid, L, mid);
        clear(rs(p), mid + 1, r, mid + 1, R);
    }
    up(p);
}

int sol(int p, int l, int r, int L, int R, int x) {
    if (!x)
        return 0;
    int tot = t[p].len - t[p].sum;
    if (l == L && R == r && tot <= x) {
        tag(p, 2);
        return x - tot;
    }
    down(p);
    int ans = 0, mid = (l + r) >> 1;
    if (R <= mid)
        ans = sol(ls(p), l, mid, L, R, x);
    else if (L > mid)
        ans = sol(rs(p), mid + 1, r, L, R, x);
    else
        ans = sol(rs(p), mid + 1, r, mid + 1, R, 
        sol(ls(p), l, mid, L, mid, x));
    up(p);
    return ans;
}

int main() {
    int n = read(), m = read();
    build(1, 1, n);
    while (m--) {
        int op = read();
        if (op == 0) {
            int l = read(), r = read();
            clear(1, 1, n, l, r);
        } else if (op == 1) {
            int l1 = read(), r1 = read(), l2 = read(), r2 = read();
            int x = query(1, 1, n, l1, r1).sum;
            clear(1, 1, n, l1, r1);
            sol(1, 1, n, l2, r2, x);
        } else {
            int l = read(), r = read();
            printf("%d\n", query(1, 1, n, l, r).ans);
        }
    }
    return 0;
}
posted @ 2023-01-04 05:44  dbxxx  阅读(35)  评论(0编辑  收藏  举报