SCOI 2010 序列操作 (线段树)
听说有人调了半年,我也来做一做
这是一道非常毒瘤的线段树模板题
它需要支持:
- 区间赋值
- 区间反转
- 区间和
- 区间最多有多少个连续的 1
线段树的难点基本上就是:
- 标记的定义
- 标记的顺序
- 标记的上传,下传
前 \(3\) 个都是基本操作,我们考虑如何搞定第四个
首先定义变量:
\(sum\) 表示区间和
\(max~[0/1]\) 表示区间连续 \(0/1\) 的最大长度
\(tag\) 表示区间整体赋为某一个数
\(rev\) 表示区间反转标记
怎么通过左右子节点维护 \(max\) ?
继续定义;
\(lmax~[0/1]\) 表示包含左端点 \(0/1\) 连续的最大长度
\(rmax~[0/1]\) 表示包含右端点 \(0/1\) 连续的最大长度
线段树的
首先考虑 \(lmax(x)\) 怎么由 \(lmax(ls)\) 和 \(lmax(rs)\) 得到
- 首先 \(lmax(x) = lmax(ls)\)
- 如果 \(lmax(ls)\) 全部满足,那么我们可以加入 \(lmax(rs)\)
\(rmax(x)\) 的处理方式同理
\(max(x)\) 呢?
要么两段合并,要么就是两段分开的答案
上面讲了上传,考虑下传
显然我们知道区间赋值操作的优先级大于区间反转
感觉下传方法没什么可说的
那么我们先处理区间赋值
然后对于区间反转,需要考虑如果叶子结点有赋值操作怎么办
这个问题纠结了一小会(我可能呆了
经过一段思考,我们发现赋值再反转就等于赋反值
所以我们不如直接把 \(tag\) 取反
然后也就是比较套路的东西了
可以看一下代码
之后讲一下写代码时遇到的一下问题
- \(\#define\) 时把 \(mx(x, i)\) 定义成了 \(tr[x].mx[2]\),花了好久才发现
- 发现操作 \(4\) 原来的线段树写法不再适用,于是函数重写了一下
- 之后过了样例,交上去获得了 10 分的成绩,把 标记全输出来,发现怎么不对啊,然后想起来这种线段树写法要把 \(pushdown\) 写在前面,再交了一遍获得了 90 分
- 把数组开大(其实没必要),把 \(O(2)\) 去掉就 \(AC\) 了,额,\(O(2)\) 害人不浅啊。。。
#include <map>
#include <set>
#include <ctime>
#include <queue>
#include <stack>
#include <cmath>
#include <vector>
#include <bitset>
#include <cstdio>
#include <cctype>
#include <string>
#include <numeric>
#include <cstring>
#include <cassert>
#include <climits>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std ;
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define per(i, a, b) for (int i = (a); i >= (b); i--)
#define loop(s, v, it) for (s::iterator it = v.begin(); it != v.end(); it++)
#define cont(i, x) for (int i = head[x]; i; i = e[i].nxt)
#define clr(a) memset(a, 0, sizeof(a))
#define ass(a, sum) memset(a, sum, sizeof(a))
#define lowbit(x) (x & -x)
#define all(x) x.begin(), x.end()
#define ub upper_bound
#define lb lower_bound
#define pq priority_queue
#define mp make_pair
#define pb push_back
#define pof pop_front
#define pob pop_back
#define fi first
#define se second
#define iv inline void
#define enter cout << endl
#define siz(x) ((int)x.size())
#define file(x) freopen(x".in", "r", stdin),freopen(x".out", "w", stdout)
typedef long long ll ;
typedef unsigned long long ull ;
typedef pair <int, int> pii ;
typedef vector <int> vi ;
typedef vector <pii> vii ;
typedef queue <int> qi ;
typedef queue <pii> qii ;
typedef set <int> si ;
typedef map <int, int> mii ;
typedef map <string, int> msi ;
const int N = 500010 ;
const int INF = 0x3f3f3f3f ;
const int iinf = 1 << 30 ;
const ll linf = 2e18 ;
const int MOD = 1000000007 ;
const double eps = 1e-7 ;
void douout(double x){ printf("%lf\n", x + 0.0000000001) ; }
template <class T> void print(T a) { cout << a << endl ; exit(0) ; }
template <class T> void chmin(T &a, T b) { if (a > b) a = b ; }
template <class T> void chmax(T &a, T b) { if (a < b) a = b ; }
template <class T> void upd(T &a, T b) { (a += b) %= MOD ; }
template <class T> void mul(T &a, T b) { a = (ll) a * b % MOD ; }
/*
sum : 区间和
max 区间 0/1 最长
lmax 左端点 0/1 最长
rmax 右端点 0/1 最长
lazy 区间是否全部赋值
rev 区间取反
*/
int n, m ;
int a[N] ;
struct Segtr {
int l, r, sum, mx[2], lmax[2], rmax[2], tag, rev ;
#define ls(x) (x << 1)
#define rs(x) ((x << 1) | 1)
#define l(x) tr[x].l
#define r(x) tr[x].r
#define sz(x) (tr[x].r - tr[x].l + 1)
#define sum(x) tr[x].sum
#define mx(x, i) tr[x].mx[i]
#define lmax(x, i) tr[x].lmax[i]
#define rmax(x, i) tr[x].rmax[i]
#define tag(x) tr[x].tag
#define rev(x) tr[x].rev
} tr[N << 2] ;
void pushup(int x) {
sum(x) = sum(ls(x)) + sum(rs(x)) ; // 区间累加即可
rep(i, 0, 1) { // 对 0 / 1 分开考虑
// 左节点如果信息全部符合, 可以把右边的也加进来
lmax(x, i) = lmax(ls(x), i) ;
if (i == 0 && sum(ls(x)) == 0) lmax(x, i) += lmax(rs(x), i) ;
if (i == 1 && sum(ls(x)) == sz(ls(x))) lmax(x, i) += lmax(rs(x), i) ;
// 同理
rmax(x, i) = rmax(rs(x), i) ;
if (i == 0 && sum(rs(x)) == 0) rmax(x, i) += rmax(ls(x), i) ;
if (i == 1 && sum(rs(x)) == sz(rs(x))) rmax(x, i) += rmax(ls(x), i) ;
mx(x, i) = rmax(ls(x), i) + lmax(rs(x), i) ;
chmax(mx(x, i), mx(ls(x), i)) ;
chmax(mx(x, i), mx(rs(x), i)) ;
}
}
void pushdown(int x) {
if (tag(x) != -1) { // 显然赋值的优先级最高
rev(x) = 0 ; // 没用了
int v = tag(x) ;
rev(ls(x)) = rev(rs(x)) = 0 ;
tag(ls(x)) = tag(rs(x)) = v ;
sum(ls(x)) = sz(ls(x)) * v ;
sum(rs(x)) = sz(rs(x)) * v ;
mx(ls(x), v) = sz(ls(x)) ;
mx(rs(x), v) = sz(rs(x)) ;
mx(ls(x), v ^ 1) = mx(rs(x), v ^ 1) = 0 ;
lmax(ls(x), v) = sz(ls(x)) ;
lmax(rs(x), v) = sz(rs(x)) ;
lmax(ls(x), v ^ 1) = lmax(rs(x), v ^ 1) = 0 ;
rmax(ls(x), v) = sz(ls(x)) ;
rmax(rs(x), v) = sz(rs(x)) ;
rmax(ls(x), v ^ 1) = rmax(rs(x), v ^ 1) = 0 ;
tag(x) = -1 ;
}
if (rev(x)) {
sum(ls(x)) = sz(ls(x)) - sum(ls(x)) ;
sum(rs(x)) = sz(rs(x)) - sum(rs(x)) ;
if (tag(ls(x)) != -1) tag(ls(x)) ^= 1 ;// 考虑到优先级,其实先赋再取反就等于赋反的
else rev(ls(x)) ^= 1 ;
if (tag(rs(x)) != -1) tag(rs(x)) ^= 1 ;
else rev(rs(x)) ^= 1 ;
swap(mx(ls(x), 0), mx(ls(x), 1)) ;
swap(mx(rs(x), 0), mx(rs(x), 1)) ;
swap(lmax(ls(x), 0), lmax(ls(x), 1)) ;
swap(lmax(rs(x), 0), lmax(rs(x), 1)) ;
swap(rmax(ls(x), 0), rmax(ls(x), 1)) ;
swap(rmax(rs(x), 0), rmax(rs(x), 1)) ;
rev(x) = 0 ;
}
}
void build(int x, int l, int r) {
l(x) = l, r(x) = r, tag(x) = -1 ;
if (l == r) {
sum(x) = a[l] ;
mx(x, 0) = lmax(x, 0) = rmax(x, 0) = a[l] == 0 ;
mx(x, 1) = lmax(x, 1) = rmax(x, 1) = a[l] == 1 ;
return ;
}
int mid = (l + r) >> 1 ;
build(ls(x), l, mid) ;
build(rs(x), mid + 1, r) ;
pushup(x) ;
}
void modify(int x, int op, int l, int r) {
pushdown(x) ;
if (l == l(x) && r(x) == r) {
if (op <= 1) { // 修改为 0 / 1
tag(x) = op ;
sum(x) = sz(x) * op ;
mx(x, op) = lmax(x, op) = rmax(x, op) = sz(x) ;
mx(x, op ^ 1) = lmax(x, op ^ 1) = rmax(x, op ^ 1) = 0 ;
} else { // 区间取反
sum(x) = sz(x) - sum(x) ;
rev(x) ^= 1 ;
swap(mx(x, 0), mx(x, 1)) ;
swap(lmax(x, 0), lmax(x, 1)) ;
swap(rmax(x, 0), rmax(x, 1)) ;
}
return ;
}
int mid = (l(x) + r(x)) >> 1 ;
if (l > mid) modify(rs(x), op, l, r) ;
else if (mid >= r) modify(ls(x), op, l, r) ;
else modify(ls(x), op, l, mid), modify(rs(x), op, mid + 1, r) ;
pushup(x) ;
}
int query(int x, int l, int r) {
pushdown(x) ;
if (l == l(x) && r(x) == r) return sum(x) ;
int mid = (l(x) + r(x)) >> 1 ;
if (l > mid) return query(rs(x), l, r) ;
else if (mid >= r) return query(ls(x), l, r) ;
else return query(ls(x), l, mid) + query(rs(x), mid + 1, r) ;
}
Segtr Query(int x, int l, int r) {
pushdown(x) ;
if (l == l(x) && r(x) == r) return tr[x] ;
int mid = (l(x) + r(x)) >> 1 ;
if (l > mid) return Query(rs(x), l, r) ;
else if (mid >= r) return Query(ls(x), l, r) ;
else {
Segtr ans, L, R ;
L = Query(ls(x), l, mid) ;
R = Query(rs(x), mid + 1, r) ;
ans.sum = L.sum + R.sum ;
rep(i, 0, 1) {
ans.lmax[i] = L.lmax[i] ;
if (i == 0 && L.sum == 0) ans.lmax[i] += R.lmax[i] ;
if (i == 1 && L.sum == L.r - L.l + 1) ans.lmax[i] += R.lmax[i] ;
ans.rmax[i] = R.rmax[i] ;
if (i == 0 && R.sum == 0) ans.rmax[i] += L.rmax[i] ;
if (i == 1 && R.sum == R.r - R.l + 1) ans.rmax[i] += L.rmax[i] ;
ans.mx[i] = L.rmax[i] + R.lmax[i] ;
chmax(ans.mx[i], L.mx[i]) ;
chmax(ans.mx[i], R.mx[i]) ;
}
return ans ;
}
}
signed main(){
scanf("%d%d", &n, &m) ;
rep(i, 1, n) scanf("%d", &a[i]) ;
build(1, 1, n) ;
while (m--) {
int op, l, r ; scanf("%d%d%d", &op, &l, &r) ;
l++ ; r++ ;
if (op <= 2) { // 前三个均为修改
modify(1, op, l, r) ;
}
if (op == 3) {
printf("%d\n", query(1, l, r)) ;
}
if (op == 4) {
printf("%d\n", Query(1, l, r).mx[1]) ;
}
}
return 0 ;
}
/*
写代码时请注意:
1.ll?数组大小,边界?数据范围?
2.精度?
3.特判?
4.至少做一些
思考提醒:
1.最大值最小->二分?
2.可以贪心么?不行dp可以么
3.可以优化么
4.维护区间用什么数据结构?
5.统计方案是用dp?模了么?
6.逆向思维?
*/
加油ヾ(◍°∇°◍)ノ゙