GYM103373F(线段树,子段问题)
GYM103373F(线段树,子段问题)
题意
给定一个01串。定义01交替或者10交替的子段为交替串。
有两个操作:
- flip区间 \([l,r]\)
- 输出区间 \([l,r]\) 内交替串的数量
思路
- 如果有极长交替串长度,可以算出交替串数量
于是考虑维护交替串的长度。我们用线段树对其维护。
维护当前结点的答案 \(res\)
前缀极长交替串和后缀极长交替串 \(mxL,xmR\)
区间左右端点的数字 \(numL,numR\)
合并区间前需要判断是否可以将两个子区间的串拼接。具体的维护pushup如下
void pushup(int p) {
tr[p].res = tr[lc].res + tr[rc].res;
tr[p].mxL = tr[lc].mxL;
tr[p].mxR = tr[rc].mxR;
tr[p].numL = tr[lc].numL;
tr[p].numR = tr[rc].numR;
if(tr[lc].numR == (tr[rc].numL ^ 1)) {
tr[p].res += tr[lc].mxR * tr[rc].mxL;
if(tr[lc].mxL == tr[lc].r - tr[lc].l + 1)
tr[p].mxL = tr[lc].mxL + tr[rc].mxL;
if(tr[rc].mxR == tr[rc].r - tr[rc].l + 1)
tr[p].mxR = tr[rc].mxR + tr[lc].mxR;
}
}
对flip操作,翻转两次等于没有翻转,因此只要用一个tag标记该子区间是否翻转即可。
代码实现上需要注意在query的过程中,仍然需要判断左右区间中间部分能否合并,操作同pushup。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#include<queue>
#include<map>
#include<stack>
#include<string>
#include<random>
#include<functional>
#include<iomanip>
#define yes puts("yes");
#define inf 0x3f3f3f3f
#define ll long long
#define linf 0x3f3f3f3f3f3f3f3fll
#define ull unsigned long long
#define endl '\n'
#define int long long
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
using namespace std;
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x;}
using PII = array<int,2>;
const int MAXN =10 + 2e5 ,mod=1e9 + 7;
int n,q;
int a[MAXN];
#define lc (p << 1)
#define rc (p << 1 | 1)
struct Node {
int l,r;
int res,mxL,mxR,numL,numR;
int tag;
}tr[MAXN << 2];
void pushup(int p) {
tr[p].res = tr[lc].res + tr[rc].res;
tr[p].mxL = tr[lc].mxL;
tr[p].mxR = tr[rc].mxR;
tr[p].numL = tr[lc].numL;
tr[p].numR = tr[rc].numR;
if(tr[lc].numR == (tr[rc].numL ^ 1)) {
tr[p].res += tr[lc].mxR * tr[rc].mxL;
if(tr[lc].mxL == tr[lc].r - tr[lc].l + 1)
tr[p].mxL = tr[lc].mxL + tr[rc].mxL;
if(tr[rc].mxR == tr[rc].r - tr[rc].l + 1)
tr[p].mxR = tr[rc].mxR + tr[lc].mxR;
}
}
void upd(int p) {
tr[p].numL ^= 1;
tr[p].numR ^= 1;
tr[p].tag ^= 1;
}
void pushdown(int p) {
if(tr[p].tag) {
upd(lc);
upd(rc);
tr[p].tag = 0;
}
}
void build(int l,int r,int p = 1) {
tr[p].l = l,tr[p].r = r;
if(l == r) {
tr[p].mxL = tr[p].mxR = tr[p].res = 1;
tr[p].numL = tr[p].numR = a[l];
tr[p].tag = 0;
return;
}
int mid = l + r >> 1;
build(l,mid,lc);
build(mid + 1,r,rc);
pushup(p);
}
void modify(int l,int r,int p = 1) {
if(l <= tr[p].l and tr[p].r <= r) {
upd(p);
return;
}
int mid = tr[p].l + tr[p].r >> 1;
pushdown(p);
if(l <= mid)
modify(l,r,lc);
if(r > mid)
modify(l,r,rc);
pushup(p);
}
Node query(int l,int r,int p = 1) {
if(l <= tr[p].l and tr[p].r <= r)
return tr[p];
int mid = tr[p].l + tr[p].r >> 1;
pushdown(p);
Node ans;
if(l <= mid and r > mid) {
Node resL,resR;
resL = query(l,r,lc);
resR = query(l,r,rc);
ans.l = resL.l;
ans.r = resR.r;
ans.res = resL.res + resR.res;
ans.mxL = resL.mxL;
ans.mxR = resR.mxR;
ans.numL = resL.numL;
ans.numR = resR.numR;
if(resL.numR == (resR.numL ^ 1)) {
ans.res += resL.mxR * resR.mxL;
if(resL.mxL == resL.r - resL.l + 1)
ans.mxL = resL.mxL + resR.mxL;
if(resR.mxR == resR.r - resR.l + 1)
ans.mxR = resR.mxR + resL.mxR;
}
return ans;
}
if(l <= mid)
return query(l,r,lc);
return query(l,r,rc);
}
void solve()
{
cin >> n >> q;
for(int i = 1;i <= n;i += 1)
cin >> a[i];
build(1,n);
while(q --) {
int op,l,r; cin >> op >> l >> r;
if(op == 1)
modify(l,r);
else
cout << query(l,r).res << endl;
}
}
signed main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
// int T;cin>>T;
// while(T--)
solve();
return 0;
}