【序列操作I】线段树
题目描述
Lxhgww 最近收到了一个 01 序列,序列里面包含了 n(1≤n≤105)个数,这些书要么是 0,要么是 1,现在对这个序列有五种变换操作和询问操作:
1. 0 a b ,把[a,b]区间内所有数全部变成 0。
2. 1 a b ,把[a,b]区间内所有数全部变成 1。
3. 2 a b ,把[a,b]区间内所有数全部取反,也就是说把所有的 0 变成 1,把所有的 1 变成 0。
4. 3 a b ,询问[a,b]区间内总共有多少个 1。
5. 4 a b ,询问[a,b]区间内最多有多少个连续的 1。
对于每一种询问操作,Lxhgww 都需要给出回答,聪明的程序员们,你们能帮助他吗?
输入格式
输入数据第一行包括 2 个数,n 和 m(1≤m≤105)分别表示序列的长度和操作数目。
第二行包括 n 个数,表示序列的初始状态.
接下来 m 行,每行 3 个数,op,a,b(0≤op≤4,0≤a≤b<n),表示对于区间[a,b]执行标号为 op 的操作。
输出格式
对于每次询问,输出单独的一行表示答案。
样例数据 1
输入
10 10
0 0 0 1 1 0 1 0 1 1
1 0 2
3 0 5
2 2 2
4 0 4
0 3 6
2 3 7
4 2 8
1 0 5
0 5 6
3 3 9
输出
5
2
6
5
题目分析
线段树裸题,关于区间最大连续的问题,都是维护左端最长连续,右端最长连续,和总的最长连续,更新即可。
比较坑的是下标的下放顺序:无论是否有反转标记都可以直接覆盖,把反转标志置为false。但若是有覆盖标记,就必须先进行覆盖标记的下传,再进行反转。
code
#include<iostream> #include<cstdlib> #include<cstdio> #include<cstring> #include<string> #include<algorithm> #include<cmath> using namespace std; const int N = 1e5 + 5; int n, m, data[N]; struct node{ int len, cnt, tag; bool rev; int lx0, rx0, lx1, rx1, mx0, mx1; node():tag(-1){} }; inline void wr(int); namespace SegTree{ node tr[N << 2]; inline void upt(int k){ tr[k].cnt = tr[k << 1].cnt + tr[k << 1 | 1].cnt; tr[k].lx0 = tr[k << 1].lx0, tr[k].lx1 = tr[k << 1].lx1; tr[k].rx0 = tr[k << 1 | 1].rx0, tr[k].rx1 = tr[k << 1 | 1].rx1; if(tr[k << 1].cnt == tr[k << 1].len) tr[k].lx1 += tr[k << 1 | 1].lx1; if(tr[k << 1].cnt == 0) tr[k].lx0 += tr[k << 1 | 1].lx0; if(tr[k << 1 | 1].cnt == tr[k << 1 | 1].len) tr[k].rx1 += tr[k << 1].rx1; if(tr[k << 1 | 1].cnt == 0) tr[k].rx0 += tr[k << 1].rx0; tr[k].mx0 = max(tr[k << 1].mx0, tr[k << 1 | 1].mx0); tr[k].mx0 = max(tr[k].mx0, tr[k << 1].rx0 + tr[k << 1 | 1].lx0); tr[k].mx1 = max(tr[k << 1].mx1, tr[k << 1 | 1].mx1); tr[k].mx1 = max(tr[k].mx1, tr[k << 1].rx1 + tr[k << 1 | 1].lx1); } inline void cover(int , int); inline void Rev(int k){ if(tr[k].tag != -1){ if(tr[k].len > 1) cover(k << 1, tr[k].tag), cover(k << 1 | 1, tr[k].tag); tr[k].tag = -1; } tr[k].cnt = tr[k].len - tr[k].cnt; swap(tr[k].lx1, tr[k].lx0); swap(tr[k].rx1, tr[k].rx0); swap(tr[k].mx0, tr[k].mx1); tr[k].rev ^= 1; } inline void cover(int k, int v){ tr[k].rev = 0; tr[k].cnt = tr[k].lx1 = tr[k].rx1 = tr[k].mx1 = (v == 1) * tr[k].len; tr[k].lx0 = tr[k].rx0 = tr[k].mx0 = (v == 0) * tr[k].len; tr[k].tag = v; } inline void pushdown(int k){ if(tr[k].tag != -1){ if(tr[k].len > 1) cover(k << 1, tr[k].tag), cover(k << 1 | 1, tr[k].tag); tr[k].tag = -1; } if(tr[k].rev){ tr[k].rev = 0; if(tr[k].len > 1) Rev(k << 1), Rev(k << 1 | 1); } } inline int queryCnt(int k, int l, int r, int x, int y){ pushdown(k); if(x <= l && r <= y) return tr[k].cnt; int mid = l + r >> 1, ret = 0; if(x <= mid) ret += queryCnt(k << 1, l, mid, x, y); if(y > mid) ret += queryCnt(k << 1 | 1, mid + 1, r, x, y); return ret; } inline node queryMx(int k, int l, int r, int x, int y){ pushdown(k); if(l == x && r == y) return tr[k]; int mid = l + r >> 1; if(y <= mid) return queryMx(k << 1, l, mid, x, y); else if(x > mid) return queryMx(k << 1 | 1, mid + 1, r, x, y); else{ node ret1 = queryMx(k << 1, l, mid, x, mid); node ret2 = queryMx(k << 1 | 1, mid + 1, r, mid + 1, y); node ret; ret.lx1 = ret1.lx1; ret.rx1 = ret2.rx1; if(ret1.cnt == ret1.len) ret.lx1 += ret2.lx1; if(ret2.cnt == ret2.len) ret.rx1 += ret1.rx1; ret.mx1 = max(ret1.mx1,ret2.mx1); ret.mx1 = max(ret.mx1, ret1.rx1 + ret2.lx1); return ret; } } inline void build(int k, int l, int r){ tr[k].len = r - l + 1; if(l == r){ tr[k].lx1 = tr[k].rx1 = tr[k].mx1 = tr[k].cnt = (data[l] == 1); tr[k].lx0 = tr[k].rx0 = tr[k].mx0 = (data[l] == 0); tr[k].rev = 0; tr[k].tag = -1; return; } int mid = l + r >> 1; build(k << 1, l, mid); build(k << 1 | 1, mid + 1, r); upt(k); } inline void modify(int k, int l, int r, int x, int y, int opt){ pushdown(k); if(x <= l && r <= y){ switch(opt){ case 0: cover(k, 0); break; case 1: cover(k, 1); break; case 2: Rev(k); break; } return; } int mid = l + r >> 1; if(x <= mid) modify(k << 1, l, mid, x, y, opt); if(y > mid) modify(k << 1 | 1, mid + 1, r, x, y, opt); upt(k); } }using namespace SegTree; inline int read(){ int i = 0, f = 1; char ch = getchar(); for(; (ch < '0' || ch > '9') && ch != '-'; ch = getchar()); if(ch == '-') f = -1, ch = getchar(); for(; ch >= '0' && ch <= '9'; ch = getchar()) i = (i << 3) + (i << 1) + (ch - '0'); return i * f; } inline void wr(int x){ if(x < 0) putchar('-'), x = -x; if(x > 9) wr(x / 10); putchar(x % 10 + '0'); } int main(){ n = read(); m = read(); for(int i = 1; i <= n; i++) data[i] = read(); build(1, 1, n); for(int i = 1; i <= m; i++){ int opt = read(); int a = read() + 1, b = read() + 1; if(opt == 0 || opt == 1 || opt == 2) modify(1, 1, n, a, b, opt); else if(opt == 3) wr(queryCnt(1, 1, n, a, b)), putchar('\n'); else wr((queryMx(1, 1, n, a, b)).mx1), putchar('\n'); } return 0; }