洛谷题单指南-线段树-P2572 [SCOI2010] 序列操作
原题链接:https://www.luogu.com.cn/problem/P2572
题意解读:对于01序列,支持几种操作:0.将区间值全变成0 1.将区间值全变成1 2.将区间值全部取反 3.查询1的个数 4.查找连续最多1的个数
解题思路:区间修改,区间查询,又是线段树的典型应用。
要查询的值有两个:1的个数,连续1的最大长度
1的个数即区间和,设定为sum,节点合并很简单,root.sum = left.sum + right.sum
连续1的个数设为maxOneCnt
要合并连续1的个数,需要一些额外信息:最大连续前缀1长度preOneCnt,最大连续后缀1长度subOneCnt
合并时,root.maxOneCnt在left.maxOneCnt、right.maxOneCnt、left.subOneCnt + right.preOneCnt中取最大值即可
而preOneCnt、subOneCnt的合并前文中已有多个例子,可以参考:https://www.cnblogs.com/jcwy/p/18582207
由于要支持区间修改全0、全1、取反,在操作之后,1的个数会和0的个数互相,因此还需要维护
连续0的个数设为maxZeroCnt
要合并连续1的个数,需要一些额外信息:最大连续前缀0长度preZeroCnt,最大连续后缀0长度subZeroCnt
合并时,root.maxZeroCnt在left.maxZeroCnt、right.maxZeroCnt、left.subZeroCnt + right.preZeroCnt中取最大值即可
preZeroCnt、subZeroCnt的合并方式与preOneCnt、subOneCnt类似。
由于要实现区间修改,因此需要懒标记记录修改的类型,
设定setval表示懒标记,含义为1:变成0,2:变成1,3:取反,默认值是0
综合以上,线段树节点定义为:
struct Node
{
int l, r;
int sum; //区间和
int maxOneCnt; //最多连续1的个数
int maxZeroCnt; //最多连续0的个数
int preOneCnt; //连续前缀1的个数
int preZeroCnt; //连续前缀0的个数
int subOneCnt; //连续后缀1的个数
int subZeroCnt; //连续后缀0的个数
int setval; //懒标记,表示将所有子节点 1:变成0,2:变成1,3:取反,默认值是0
} tr[N * 4];
pushup操作:
void pushup(Node &root ,Node &left, Node &right)
{
root.sum = left.sum + right.sum;
root.maxOneCnt = max(max(left.maxOneCnt, right.maxOneCnt), left.subOneCnt + right.preOneCnt);
root.maxZeroCnt = max(max(left.maxZeroCnt, right.maxZeroCnt), left.subZeroCnt + right.preZeroCnt);
root.preOneCnt = left.preOneCnt;
if(left.maxOneCnt == left.r - left.l + 1)
root.preOneCnt = left.maxOneCnt + right.preOneCnt;
root.preZeroCnt = left.preZeroCnt;
if(left.maxZeroCnt == left.r - left.l + 1)
root.preZeroCnt = left.maxZeroCnt + right.preZeroCnt;
root.subOneCnt = right.subOneCnt;
if(right.maxOneCnt == right.r - right.l + 1)
root.subOneCnt = right.maxOneCnt + left.subOneCnt;
root.subZeroCnt = right.subZeroCnt;
if(right.maxZeroCnt == right.r - right.l + 1)
root.subZeroCnt = right.maxZeroCnt + left.subZeroCnt;
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
修改节点打标记操作:
void addtag(int u, int val)
{
int len = tr[u].r - tr[u].l + 1;
if(val == 1) //全部改成0
{
tr[u].sum = 0;
tr[u].maxOneCnt = 0;
tr[u].maxZeroCnt = len;
tr[u].preOneCnt = 0;
tr[u].preZeroCnt = len;
tr[u].subOneCnt = 0;
tr[u].subZeroCnt = len;
tr[u].setval = val;
}
else if(val == 2) //全部改成1
{
tr[u].sum = len;
tr[u].maxOneCnt = len;
tr[u].maxZeroCnt = 0;
tr[u].preOneCnt = len;
tr[u].preZeroCnt = 0;
tr[u].subOneCnt = len;
tr[u].subZeroCnt = 0;
tr[u].setval = val;
}
else //全部取反
{
tr[u].sum = len - tr[u].sum;
swap(tr[u].maxOneCnt, tr[u].maxZeroCnt);
swap(tr[u].preOneCnt, tr[u].preZeroCnt);
swap(tr[u].subOneCnt, tr[u].subZeroCnt);
if(tr[u].setval == 0) tr[u].setval = val; //如果之前没有标记,则复制
else if(tr[u].setval == 1) tr[u].setval = 2; //如果之前是变成0,则改成变成1
else if(tr[u].setval == 2) tr[u].setval = 1; //如果之前是变成1,则改成变成0
else tr[u].setval = 0; //如果之前是取反,则重置为不做任何操作
}
}
100分代码:
#include <bits/stdc++.h>
using namespace std;
const int N = 100005;
struct Node
{
int l, r;
int sum; //区间和
int maxOneCnt; //最多连续1的个数
int maxZeroCnt; //最多连续0的个数
int preOneCnt; //连续前缀1的个数
int preZeroCnt; //连续前缀0的个数
int subOneCnt; //连续后缀1的个数
int subZeroCnt; //连续后缀0的个数
int setval; //懒标记,表示将所有子节点 1:变成0,2:变成1,3:取反,默认值是0
} tr[N * 4];
int a[N];
int n, m;
void pushup(Node &root ,Node &left, Node &right)
{
root.sum = left.sum + right.sum;
root.maxOneCnt = max(max(left.maxOneCnt, right.maxOneCnt), left.subOneCnt + right.preOneCnt);
root.maxZeroCnt = max(max(left.maxZeroCnt, right.maxZeroCnt), left.subZeroCnt + right.preZeroCnt);
root.preOneCnt = left.preOneCnt;
if(left.maxOneCnt == left.r - left.l + 1)
root.preOneCnt = left.maxOneCnt + right.preOneCnt;
root.preZeroCnt = left.preZeroCnt;
if(left.maxZeroCnt == left.r - left.l + 1)
root.preZeroCnt = left.maxZeroCnt + right.preZeroCnt;
root.subOneCnt = right.subOneCnt;
if(right.maxOneCnt == right.r - right.l + 1)
root.subOneCnt = right.maxOneCnt + left.subOneCnt;
root.subZeroCnt = right.subZeroCnt;
if(right.maxZeroCnt == right.r - right.l + 1)
root.subZeroCnt = right.maxZeroCnt + left.subZeroCnt;
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
tr[u] = {l, r};
if(l == r)
{
tr[u].sum = a[l];
tr[u].maxOneCnt = a[l];
tr[u].maxZeroCnt = 1 - a[l];
tr[u].preOneCnt = a[l];
tr[u].preZeroCnt = 1 - a[l];
tr[u].subOneCnt = a[l];
tr[u].subZeroCnt = 1 - a[l];
}
else
{
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void addtag(int u, int val)
{
int len = tr[u].r - tr[u].l + 1;
if(val == 1) //全部改成0
{
tr[u].sum = 0;
tr[u].maxOneCnt = 0;
tr[u].maxZeroCnt = len;
tr[u].preOneCnt = 0;
tr[u].preZeroCnt = len;
tr[u].subOneCnt = 0;
tr[u].subZeroCnt = len;
tr[u].setval = val;
}
else if(val == 2) //全部改成1
{
tr[u].sum = len;
tr[u].maxOneCnt = len;
tr[u].maxZeroCnt = 0;
tr[u].preOneCnt = len;
tr[u].preZeroCnt = 0;
tr[u].subOneCnt = len;
tr[u].subZeroCnt = 0;
tr[u].setval = val;
}
else //全部取反
{
tr[u].sum = len - tr[u].sum;
swap(tr[u].maxOneCnt, tr[u].maxZeroCnt);
swap(tr[u].preOneCnt, tr[u].preZeroCnt);
swap(tr[u].subOneCnt, tr[u].subZeroCnt);
if(tr[u].setval == 0) tr[u].setval = val; //如果之前没有标记,则复制
else if(tr[u].setval == 1) tr[u].setval = 2; //如果之前是变成0,则改成变成1
else if(tr[u].setval == 2) tr[u].setval = 1; //如果之前是变成1,则改成变成0
else tr[u].setval = 0; //如果之前是取反,则重置为不做任何操作
}
}
void pushdown(int u)
{
if(tr[u].setval)
{
addtag(u << 1, tr[u].setval);
addtag(u << 1 | 1, tr[u].setval);
tr[u].setval = 0;
}
}
Node query(int u, int l, int r)
{
if(tr[u].l >= l && tr[u].r <= r) return tr[u];
else if(tr[u].l > r || tr[u].r < l) return Node{};
else
{
pushdown(u);
Node res = {};
Node left = query(u << 1, l, r);
Node right = query(u << 1 | 1, l, r);
pushup(res, left, right);
return res;
}
}
void update(int u, int l, int r, int val)
{
if(tr[u].l >= l && tr[u].r <= r) addtag(u, val);
else if(tr[u].l > r || tr[u].r < l) return;
else
{
pushdown(u);
update(u << 1, l, r, val);
update(u << 1 | 1, l, r, val);
pushup(u);
}
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
int op, l, r;
while(m--)
{
cin >> op >> l >> r;
l++, r++; //由于输入中下标从0开始,都加上1
if(op == 0) update(1, l, r, 1);
else if(op == 1) update(1, l, r, 2);
else if(op == 2) update(1, l, r, 3);
else if(op == 3) cout << query(1, l, r).sum << endl;
else cout << query(1, l, r).maxOneCnt << endl;
}
return 0;
}