GYM103373F(线段树,子段问题)

GYM103373F(线段树,子段问题)

题意

给定一个01串。定义01交替或者10交替的子段为交替串。

有两个操作:

  1. flip区间 \([l,r]\)
  2. 输出区间 \([l,r]\) 内交替串的数量

思路

  1. 如果有极长交替串长度,可以算出交替串数量

于是考虑维护交替串的长度。我们用线段树对其维护。

维护当前结点的答案 \(res\)

前缀极长交替串和后缀极长交替串 \(mxL,xmR\)

区间左右端点的数字 \(numL,numR\)

合并区间前需要判断是否可以将两个子区间的串拼接。具体的维护pushup如下

void pushup(int p) {
    tr[p].res = tr[lc].res + tr[rc].res;
    tr[p].mxL = tr[lc].mxL;
    tr[p].mxR = tr[rc].mxR;
    tr[p].numL = tr[lc].numL;
    tr[p].numR = tr[rc].numR;
    
    if(tr[lc].numR == (tr[rc].numL ^ 1)) {
        tr[p].res += tr[lc].mxR * tr[rc].mxL;
        
        if(tr[lc].mxL == tr[lc].r - tr[lc].l + 1)
            tr[p].mxL = tr[lc].mxL + tr[rc].mxL;
        
        if(tr[rc].mxR == tr[rc].r - tr[rc].l + 1)
            tr[p].mxR = tr[rc].mxR + tr[lc].mxR;
    }
}

对flip操作,翻转两次等于没有翻转,因此只要用一个tag标记该子区间是否翻转即可。

代码实现上需要注意在query的过程中,仍然需要判断左右区间中间部分能否合并,操作同pushup。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#include<queue>
#include<map>
#include<stack>
#include<string>
#include<random>
#include<functional>
#include<iomanip>
#define yes puts("yes");
#define inf 0x3f3f3f3f
#define ll long long
#define linf 0x3f3f3f3f3f3f3f3fll
#define ull unsigned long long
#define endl '\n'
#define int long long
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
using namespace std;
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x;}
using PII = array<int,2>;
const int MAXN =10 + 2e5 ,mod=1e9 + 7;
int n,q;
int a[MAXN];

#define lc (p << 1)
#define rc (p << 1 | 1)
struct Node {
    int l,r;
    int res,mxL,mxR,numL,numR;
    int tag;
}tr[MAXN << 2];
void pushup(int p) {
    tr[p].res = tr[lc].res + tr[rc].res;
    tr[p].mxL = tr[lc].mxL;
    tr[p].mxR = tr[rc].mxR;
    tr[p].numL = tr[lc].numL;
    tr[p].numR = tr[rc].numR;
    
    if(tr[lc].numR == (tr[rc].numL ^ 1)) {
        tr[p].res += tr[lc].mxR * tr[rc].mxL;
        
        if(tr[lc].mxL == tr[lc].r - tr[lc].l + 1)
            tr[p].mxL = tr[lc].mxL + tr[rc].mxL;
        
        if(tr[rc].mxR == tr[rc].r - tr[rc].l + 1)
            tr[p].mxR = tr[rc].mxR + tr[lc].mxR;
    }
}
void upd(int p) {
    tr[p].numL ^= 1;
    tr[p].numR ^= 1;
    tr[p].tag ^= 1;
}
void pushdown(int p) {
    if(tr[p].tag) {
        upd(lc);
        upd(rc);
        tr[p].tag = 0;
    }
}
void build(int l,int r,int p = 1) {
    tr[p].l = l,tr[p].r = r;
    if(l == r) {
        tr[p].mxL = tr[p].mxR = tr[p].res = 1;
        tr[p].numL = tr[p].numR = a[l];
        tr[p].tag = 0;
        return;
    }
    int mid = l + r >> 1;
    build(l,mid,lc);
    build(mid + 1,r,rc);
    pushup(p);
}
void modify(int l,int r,int p = 1) {
    if(l <= tr[p].l and tr[p].r <= r) {
        upd(p);
        return;
    }
    int mid = tr[p].l + tr[p].r >> 1;
    pushdown(p);
    if(l <= mid)
        modify(l,r,lc);
    if(r > mid)
        modify(l,r,rc);
    pushup(p);
}
Node query(int l,int r,int p = 1) {
    if(l <= tr[p].l and tr[p].r <= r) 
        return tr[p];
    
    int mid = tr[p].l + tr[p].r >> 1;
    
    pushdown(p);
    
    Node ans;
    if(l <= mid and r > mid) {
        Node resL,resR;
        resL = query(l,r,lc);
        resR = query(l,r,rc);

        ans.l = resL.l;
        ans.r = resR.r;
        ans.res = resL.res + resR.res;
        ans.mxL = resL.mxL;
        ans.mxR = resR.mxR;
        ans.numL = resL.numL;
        ans.numR = resR.numR;
        
        if(resL.numR == (resR.numL ^ 1)) {
            ans.res += resL.mxR * resR.mxL;
            
            if(resL.mxL == resL.r - resL.l + 1)
                ans.mxL = resL.mxL + resR.mxL;
            
            if(resR.mxR == resR.r - resR.l + 1)
                ans.mxR = resR.mxR + resL.mxR;
        }
        
        return ans;
    }
    if(l <= mid) 
        return query(l,r,lc);
    return query(l,r,rc);
    
}
void solve()
{    
    cin >> n >> q;
    for(int i = 1;i <= n;i += 1)
        cin >> a[i];
    
    build(1,n);
    
    while(q --) {
        int op,l,r; cin >> op >> l >> r;
        if(op == 1) 
            modify(l,r);
        else
            cout << query(l,r).res << endl;
    }
}
signed main()
{
    ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);

    // int T;cin>>T;
    // while(T--)
        solve();

    return 0;
}
posted @ 2022-08-08 09:49  Mxrurush  阅读(30)  评论(0编辑  收藏  举报