LowbitMatrix(线段树)

题目

Problem - 7116

题解

一个数最多加\(\log n\)次lowbit,之后只需乘2即可。因此可以结合线段树暴力,没好的暴力加,加好的直接打标记乘2。

原本我的方法是并查集维护那些区间乘2,那些区间暴力加,并查集合并。这样做时间复杂度相似,但是常数巨大。除了并查集本身的复杂度,每次更新都是从线段树的\([1,n]\)第一层开始向下更新,这是常数很大的\(O(\log n)\)

因此直接在线段树里维护就好,用一个tag数组标记当前区间是需要递归下去加lowbit还是直接lazy标记乘2。

#include <bits/stdc++.h>

#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define mp make_pair
#define seteps(N) fixed << setprecision(N) 
typedef long long ll;

using namespace std;
/*-----------------------------------------------------------------*/

ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f

const int N = 3e5 + 10;
const int M = 998244353;
const double eps = 1e-5;

int arr[N];
ll sum[N << 2], pw[N];
int tag[N << 2], lazy[N << 2];

ll lowbit(ll x) {
    return x&-x;
}

void pushdown(int rt) {
    if(lazy[rt]) {
        lazy[rt << 1] += lazy[rt];
        lazy[rt << 1 | 1] += lazy[rt];
        sum[rt << 1] = sum[rt << 1] * pw[lazy[rt]] % M;
        sum[rt << 1 | 1] = sum[rt << 1 | 1] * pw[lazy[rt]] % M;
        lazy[rt] = 0;
    }
}

void init(int l, int r, int rt) {
    if(l == r) {
        sum[rt] = arr[l];
        tag[rt] = (arr[l] == lowbit(arr[l]));
        return ;
    }
    int mid = (l + r) / 2;
    init(l, mid, rt << 1);
    init(mid + 1, r, rt << 1 | 1);
    sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % M;
    tag[rt] = (tag[rt << 1] && tag[rt << 1 | 1]);
}

void update(int l, int r, int L, int R, int rt) {
    if(tag[rt] && l >= L && r <= R) {
        lazy[rt]++;
        sum[rt] = sum[rt] * 2 % M;
        return ;
    }
    if(l == r) {
        sum[rt] += lowbit(sum[rt]); 
        if(sum[rt] == lowbit(sum[rt])) tag[rt] = 1;
        return ;
    }
    pushdown(rt);
    int mid = (l + r) /2;
    if(L <= mid) update(l, mid, L, R, rt << 1);
    if(R > mid) update(mid + 1, r, L, R, rt << 1 | 1);
    sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % M;
    tag[rt] = (tag[rt << 1] && tag[rt << 1 | 1]);
}

ll que(int l, int r, int L, int R, int rt) {
    if(l >= L && r <= R) {
        return sum[rt];
    }
    pushdown(rt);
    ll res = 0;
    int mid = (l + r) /2;
    if(L <= mid) res += que(l, mid, L, R, rt << 1);
    if(R > mid) res += que(mid + 1, r, L, R, rt << 1 | 1);
    sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % M;
    tag[rt] = (tag[rt << 1] && tag[rt << 1 | 1]);
    return res % M;
}


int main() {
    pw[0] = 1;
    for(int i = 1; i < N; i++) pw[i] = pw[i - 1] * 2 % M;
    IOS;
    int t;
    cin >> t;
    while(t--) {
        int n;
        cin >> n;
        for(int i = 1; i <= n; i++) cin >> arr[i];
        init(1, n, 1);
        int q;
        cin >> q;
        while(q--) {
            int op, l, r;
            cin >> op >> l >> r;
            if(op == 1) {
                update(1, n, l, r, 1);
            } else {
                cout << que(1, n, l, r, 1) << endl;
            }
        }
    }    
}
posted @ 2021-09-04 19:19  limil  阅读(37)  评论(0编辑  收藏  举报