SCOI 2010 序列操作 (线段树)

听说有人调了半年,我也来做一做

这是一道非常毒瘤的线段树模板题

它需要支持:

  • 区间赋值
  • 区间反转
  • 区间和
  • 区间最多有多少个连续的 1

线段树的难点基本上就是:

  • 标记的定义
  • 标记的顺序
  • 标记的上传,下传

\(3\) 个都是基本操作,我们考虑如何搞定第四个

首先定义变量:

\(sum\) 表示区间和
\(max~[0/1]\) 表示区间连续 \(0/1\) 的最大长度
\(tag\) 表示区间整体赋为某一个数
\(rev\) 表示区间反转标记

怎么通过左右子节点维护 \(max\) ?

继续定义;

\(lmax~[0/1]\) 表示包含左端点 \(0/1\) 连续的最大长度
\(rmax~[0/1]\) 表示包含右端点 \(0/1\) 连续的最大长度

线段树的

首先考虑 \(lmax(x)\) 怎么由 \(lmax(ls)\)\(lmax(rs)\) 得到

  • 首先 \(lmax(x) = lmax(ls)\)
  • 如果 \(lmax(ls)\) 全部满足,那么我们可以加入 \(lmax(rs)\)

\(rmax(x)\) 的处理方式同理

\(max(x)\) 呢?

要么两段合并,要么就是两段分开的答案

上面讲了上传,考虑下传

显然我们知道区间赋值操作的优先级大于区间反转

感觉下传方法没什么可说的

那么我们先处理区间赋值

然后对于区间反转,需要考虑如果叶子结点有赋值操作怎么办

这个问题纠结了一小会(我可能呆了

经过一段思考,我们发现赋值再反转就等于赋反值

所以我们不如直接把 \(tag\) 取反

然后也就是比较套路的东西了

可以看一下代码

之后讲一下写代码时遇到的一下问题

  • \(\#define\) 时把 \(mx(x, i)\) 定义成了 \(tr[x].mx[2]\),花了好久才发现
  • 发现操作 \(4\) 原来的线段树写法不再适用,于是函数重写了一下
  • 之后过了样例,交上去获得了 10 分的成绩,把 标记全输出来,发现怎么不对啊,然后想起来这种线段树写法要把 \(pushdown\) 写在前面,再交了一遍获得了 90 分
  • 把数组开大(其实没必要),把 \(O(2)\) 去掉就 \(AC\) 了,额,\(O(2)\) 害人不浅啊。。。
#include <map>
#include <set>
#include <ctime>
#include <queue>
#include <stack>
#include <cmath>
#include <vector>
#include <bitset>
#include <cstdio>
#include <cctype>
#include <string>
#include <numeric>
#include <cstring>
#include <cassert>
#include <climits>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std ;
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define per(i, a, b) for (int i = (a); i >= (b); i--)
#define loop(s, v, it) for (s::iterator it = v.begin(); it != v.end(); it++)
#define cont(i, x) for (int i = head[x]; i; i = e[i].nxt)
#define clr(a) memset(a, 0, sizeof(a))
#define ass(a, sum) memset(a, sum, sizeof(a))
#define lowbit(x) (x & -x)
#define all(x) x.begin(), x.end()
#define ub upper_bound
#define lb lower_bound
#define pq priority_queue
#define mp make_pair
#define pb push_back
#define pof pop_front
#define pob pop_back
#define fi first
#define se second
#define iv inline void
#define enter cout << endl
#define siz(x) ((int)x.size())
#define file(x) freopen(x".in", "r", stdin),freopen(x".out", "w", stdout)
typedef long long ll ;
typedef unsigned long long ull ;
typedef pair <int, int> pii ;
typedef vector <int> vi ;
typedef vector <pii> vii ;
typedef queue <int> qi ;
typedef queue <pii> qii ;
typedef set <int> si ;
typedef map <int, int> mii ;
typedef map <string, int> msi ;
const int N = 500010 ;
const int INF = 0x3f3f3f3f ;
const int iinf = 1 << 30 ;
const ll linf = 2e18 ;
const int MOD = 1000000007 ;
const double eps = 1e-7 ;
void douout(double x){ printf("%lf\n", x + 0.0000000001) ; }
template <class T> void print(T a) { cout << a << endl ; exit(0) ; }
template <class T> void chmin(T &a, T b) { if (a > b) a = b ; }
template <class T> void chmax(T &a, T b) { if (a < b) a = b ; }
template <class T> void upd(T &a, T b) { (a += b) %= MOD ; }
template <class T> void mul(T &a, T b) { a = (ll) a * b % MOD ; }

/*
sum : 区间和
max 区间 0/1 最长
lmax 左端点 0/1 最长
rmax 右端点 0/1 最长
lazy 区间是否全部赋值
rev 区间取反
*/

int n, m ;
int a[N] ;

struct Segtr {
    int l, r, sum, mx[2], lmax[2], rmax[2], tag, rev ;
    #define ls(x) (x << 1)
    #define rs(x) ((x << 1) | 1)
    #define l(x) tr[x].l
    #define r(x) tr[x].r
    #define sz(x) (tr[x].r - tr[x].l + 1)
    #define sum(x) tr[x].sum
    #define mx(x, i) tr[x].mx[i]
    #define lmax(x, i) tr[x].lmax[i]
    #define rmax(x, i) tr[x].rmax[i]
    #define tag(x) tr[x].tag
    #define rev(x) tr[x].rev
} tr[N << 2] ;

void pushup(int x) {
    sum(x) = sum(ls(x)) + sum(rs(x)) ; // 区间累加即可
    rep(i, 0, 1) { // 对 0 / 1 分开考虑
        // 左节点如果信息全部符合, 可以把右边的也加进来
        lmax(x, i) = lmax(ls(x), i) ;
        if (i == 0 && sum(ls(x)) == 0) lmax(x, i) += lmax(rs(x), i) ;
        if (i == 1 && sum(ls(x)) == sz(ls(x))) lmax(x, i) += lmax(rs(x), i) ;
        // 同理
        rmax(x, i) = rmax(rs(x), i) ;
        if (i == 0 && sum(rs(x)) == 0) rmax(x, i) += rmax(ls(x), i) ;
        if (i == 1 && sum(rs(x)) == sz(rs(x))) rmax(x, i) += rmax(ls(x), i) ;

        mx(x, i) = rmax(ls(x), i) + lmax(rs(x), i) ;
        chmax(mx(x, i), mx(ls(x), i)) ;
        chmax(mx(x, i), mx(rs(x), i)) ;
    }
}

void pushdown(int x) {
    if (tag(x) != -1) { // 显然赋值的优先级最高
        rev(x) = 0 ; // 没用了
        int v = tag(x) ;
        rev(ls(x)) = rev(rs(x)) = 0 ;
        tag(ls(x)) = tag(rs(x)) = v ;
        sum(ls(x)) = sz(ls(x)) * v ;
        sum(rs(x)) = sz(rs(x)) * v ;
        mx(ls(x), v) = sz(ls(x)) ;
        mx(rs(x), v) = sz(rs(x)) ;
        mx(ls(x), v ^ 1) = mx(rs(x), v ^ 1) = 0 ;
        lmax(ls(x), v) = sz(ls(x)) ;
        lmax(rs(x), v) = sz(rs(x)) ;
        lmax(ls(x), v ^ 1) = lmax(rs(x), v ^ 1) = 0 ;
        rmax(ls(x), v) = sz(ls(x)) ; 
        rmax(rs(x), v) = sz(rs(x)) ;
        rmax(ls(x), v ^ 1) = rmax(rs(x), v ^ 1) = 0 ;
        tag(x) = -1 ;
    }
    if (rev(x)) {
        sum(ls(x)) = sz(ls(x)) - sum(ls(x)) ;
        sum(rs(x)) = sz(rs(x)) - sum(rs(x)) ;
        if (tag(ls(x)) != -1) tag(ls(x)) ^= 1 ;// 考虑到优先级,其实先赋再取反就等于赋反的
        else rev(ls(x)) ^= 1 ;
        if (tag(rs(x)) != -1) tag(rs(x)) ^= 1 ;
        else rev(rs(x)) ^= 1 ;
        swap(mx(ls(x), 0), mx(ls(x), 1)) ;
        swap(mx(rs(x), 0), mx(rs(x), 1)) ;
        swap(lmax(ls(x), 0), lmax(ls(x), 1)) ;
        swap(lmax(rs(x), 0), lmax(rs(x), 1)) ;
        swap(rmax(ls(x), 0), rmax(ls(x), 1)) ;
        swap(rmax(rs(x), 0), rmax(rs(x), 1)) ; 
        rev(x) = 0 ;
    }
}

void build(int x, int l, int r) {
    l(x) = l, r(x) = r, tag(x) = -1 ;
    if (l == r) {
        sum(x) = a[l] ;
        mx(x, 0) = lmax(x, 0) = rmax(x, 0) = a[l] == 0 ;
        mx(x, 1) = lmax(x, 1) = rmax(x, 1) = a[l] == 1 ;
        return ;
    }
    int mid = (l + r) >> 1 ;
    build(ls(x), l, mid) ;
    build(rs(x), mid + 1, r) ;
    pushup(x) ;
}

void modify(int x, int op, int l, int r) {
    pushdown(x) ;
    if (l == l(x) && r(x) == r) {
        if (op <= 1) { // 修改为 0 / 1
            tag(x) = op ;
            sum(x) = sz(x) * op ;
            mx(x, op) = lmax(x, op) = rmax(x, op) = sz(x) ;
            mx(x, op ^ 1) = lmax(x, op ^ 1) = rmax(x, op ^ 1) = 0 ;
        } else { // 区间取反
            sum(x) = sz(x) - sum(x) ;
            rev(x) ^= 1 ;
            swap(mx(x, 0), mx(x, 1)) ;
            swap(lmax(x, 0), lmax(x, 1)) ;
            swap(rmax(x, 0), rmax(x, 1)) ;
        }
        return ;
    }
    int mid = (l(x) + r(x)) >> 1 ;
    if (l > mid) modify(rs(x), op, l, r) ;
    else if (mid >= r) modify(ls(x), op, l, r) ;
    else modify(ls(x), op, l, mid), modify(rs(x), op, mid + 1, r) ;
    pushup(x) ;
}

int query(int x, int l, int r) {
    pushdown(x) ;
    if (l == l(x) && r(x) == r) return sum(x) ;
    int mid = (l(x) + r(x)) >> 1 ;
    if (l > mid) return query(rs(x), l, r) ;
    else if (mid >= r) return query(ls(x), l, r) ;
    else return query(ls(x), l, mid) + query(rs(x), mid + 1, r) ;
}

Segtr Query(int x, int l, int r) {    
    pushdown(x) ;
    if (l == l(x) && r(x) == r) return tr[x] ;
    int mid = (l(x) + r(x)) >> 1 ;
    if (l > mid) return Query(rs(x), l, r) ;
    else if (mid >= r) return Query(ls(x), l, r) ;
    else {
        Segtr ans, L, R ;
        L = Query(ls(x), l, mid) ;
        R = Query(rs(x), mid + 1, r) ;
        ans.sum = L.sum + R.sum ;
        rep(i, 0, 1) {
            ans.lmax[i] = L.lmax[i] ;
            if (i == 0 && L.sum == 0) ans.lmax[i] += R.lmax[i] ;
            if (i == 1 && L.sum == L.r - L.l + 1) ans.lmax[i] += R.lmax[i] ;

            ans.rmax[i] = R.rmax[i] ;
            if (i == 0 && R.sum == 0) ans.rmax[i] += L.rmax[i] ;
            if (i == 1 && R.sum == R.r - R.l + 1) ans.rmax[i] += L.rmax[i] ;

            ans.mx[i] = L.rmax[i] + R.lmax[i] ;
            chmax(ans.mx[i], L.mx[i]) ;
            chmax(ans.mx[i], R.mx[i]) ;
        }
        return ans ;
    }
}

signed main(){
    scanf("%d%d", &n, &m) ;
    rep(i, 1, n) scanf("%d", &a[i]) ;
    build(1, 1, n) ;
    while (m--) {
        int op, l, r ; scanf("%d%d%d", &op, &l, &r) ;
        l++ ; r++ ;
        if (op <= 2) { // 前三个均为修改
            modify(1, op, l, r) ;
        } 
        if (op == 3) {
            printf("%d\n", query(1, l, r)) ;
        }
        if (op == 4) {
            printf("%d\n", Query(1, l, r).mx[1]) ;
        }
    }
	return 0 ;
}

/*
写代码时请注意:
	1.ll?数组大小,边界?数据范围?
	2.精度?
	3.特判?
	4.至少做一些
思考提醒:
	1.最大值最小->二分?
	2.可以贪心么?不行dp可以么
	3.可以优化么
	4.维护区间用什么数据结构?
	5.统计方案是用dp?模了么?
	6.逆向思维?
*/



posted @ 2019-04-10 22:00  harryhqg  阅读(182)  评论(0编辑  收藏  举报