整体二分学习笔记

整体二分

引入

对于一堆询问,如果每个单独的询问都可以二分解决的话,时间复杂度为 \(O(NM\log N)\),但实际上每次二分都会有一些残留信息被我们扔掉,如果我们将所有询问一起二分,就可以最大时间的减小复杂度。

讲解

经典例题:区间第k大

给定一个序列 a 和一个整数 S,有 2 种操作:

1. 将 a 序列的第 k 个数变为 w
2. 查询区间[l, r]中有多少数小于等于 S

这个题可以用一个树状数组来维护,对于 \(a\) 序列,我们将所有小于等于 \(S\) 的数的位置在树状数组中 \(+1\),表示这个位置有一个小于等于 \(S\) 的数。

对于 1 操作,我们可以看作删除一个数再添加一个数,先看 \(a_k\) 是否大于 S,如果不大于则让这个位置 \(-1\),再看 \(w\) 是否大于 \(S\),如果小于等于就在这个位置 \(+1\)

对于 2 操作,可以直接执行 ask(r) - ask(l - 1),即为这个范围内有多少数小于等于 \(S\)

给定 a 序列,求第 k 小的数是几

这个问题也可以用二分来解决。每次二分一个 \(mid\),查询值域 \([l, r]\) 内有多少小于 \(mid\) 的数,记为 \(cnt\)。如果 \(cnt >= k\),可以在值域 \([l, mid]\) 中接着二分。如果 \(cnt < k\),可以令 k -= cnt,在值域 \([mid + 1, r]\) 中继续二分。

给定 a 序列,有 m 个询问,每次询问区间[l, r]中的第 k 小数

我们可以按照上面第二题的思路,每个询问进行一次二分,时间复杂度为 \(O(NM\log N)\),不能承受。

考虑每次二分值域,都会有一大部分信息被扔掉。所以我们应该对所有的询问全部进行二分。

算法流程如下:

  1. 值域达到边界,直接将当前的这些达到边界的询问的答案记录,返回即可。
  2. 二分出 mid,按照上面第一题的方法,用树状数组查找 \([l, r]\) 中不大于 \(mid\) 的数的个数,记为 \(cnt\)
  3. 再应用第二题的方法,将 \(k <= cnt\) 的询问放到 lq 序列中, 将 \(k > cnt\) 的询问的 k -= cnt,再放到 rq 序列中。
  4. 递归二分求解 lq 和 rq 序列。

同时,为了简化代码,将序列转化为 \(n\) 个插入操作,具体实现可以看上面的第一题。

代码:


#include <bits/stdc++.h>
using namespace std;

const int N = 1e6 + 10, INF = 0x3f3f3f3f;
struct Q
{
    int op, x, y, k;
}q[N], lq[N], rq[N];

int ans[N];
int tt;
int n, m;
vector<int> nums;

struct tree_array
{
    int c[N];
    #define lowbit(x) x & -x

    inline void add(int x, int val)
    {
        for(; x <= n; x += lowbit(x)) c[x] += val;
    }

    inline int query(int x)
    {
        int res = 0;
        for(; x; x -= lowbit(x)) res += c[x];
        return res;
    }
}bit;

void solve(int lval, int rval, int st, int ed)
{
    if(st > ed) return;
    if(lval == rval)
    {
        for(int i = st; i <= ed; i ++ )
            if(q[i].op > 0) ans[q[i].op] = lval;
        return;
    }

    int mid = lval + rval >> 1;
    int lt = 0, rt = 0;
    for(int i = st; i <= ed; i ++ )
    {
        if(q[i].op == 0)
        {
            if(q[i].y <= mid) bit.add(q[i].x, 1), lq[++ lt] = q[i];
            else rq[++ rt] = q[i];
        }
        else
        {
            int l = q[i].x, r = q[i].y;
            int cnt = bit.query(r) - bit.query(l - 1);
            if(cnt >= q[i].k) lq[++ lt] = q[i];
            else q[i].k -= cnt, rq[++ rt] = q[i];
        }
    }

    for(int i = ed; i >= st; i -- ) 
        if(q[i].op == 0 && q[i].y <= mid)
            bit.add(q[i].x, -1);
        
    for(int i = 1; i <= lt; i ++ ) q[st + i - 1] = lq[i];
    for(int i = 1; i <= rt; i ++ ) q[st + lt + i - 1] = rq[i];

    solve(lval, mid, st, st + lt - 1);
    solve(mid + 1, rval, st + lt, ed);
}

int main()
{
    n = read(), m = read();
    memset(bit.c, 0, sizeof bit.c);
    for(int i = 1; i <= n; i ++ )
    {
        int val;
        scanf("%d", &val);
        q[++ tt] = {0, i, val, 0};
    }

    for(int i = 1; i <= m; i ++ )
    {
        int l, r, k;
        scanf("%d%d%d", &l, &r, &k);
        q[++ tt] = {i, l, r, k};
    }

    solve(-INF, INF, 1, tt);

    for(int i = 1; i <= m; i ++ )
        printf("%d\n", ans[i]);
    
    return 0;
}

扩展:带修区间第 k 大

当然可以用树套树在线做,但是也可以用整体二分,而且运行起来更加优秀。

将每个修改操作看作 2 种操作,和上面的第一题一样,小于 \(mid\) 的就 \(-1\),添加的数小于 \(mid\)\(+1\)

时间复杂度 \(O(N\log N)\)

完整代码:


#include <bits/stdc++.h>
using namespace std;

const int N = 1e6 + 10, INF = 0x3f3f3f3f;
int n, m, tt, id;

struct tree_array
{
    int c[N];
    #define lowbit(x) x & -x

    inline void add(int x, int val)
    {
        for(; x <= n; x += lowbit(x)) c[x] += val;
    }

    inline int query(int x)
    {
        int res = 0;
        for(; x; x -= lowbit(x)) res += c[x];
        return res;
    }
} bit;

struct Q
{
    int op, x, y, k;
}q[N], lq[N], rq[N];

int ans[N], a[N];

void solve(int lval, int rval, int st, int ed)
{
    if(st > ed) return;
    if(lval == rval)
    {
        for(int i = st; i <= ed; i ++ )
            if(q[i].op > 0) 
                ans[q[i].op] = lval;
        return;
    }

    int mid = lval + rval >> 1;
    int lt = 0, rt = 0;
    for(int i = st; i <= ed; i ++ )
    {
        if(q[i].op <= 0)
        {
            if(q[i].y <= mid) bit.add(q[i].x, q[i].k), lq[++ lt] = q[i];
            else rq[++ rt] = q[i];
        }
        else
        {
            int l = q[i].x, r = q[i].y;
            int cnt = bit.query(r) - bit.query(l - 1);
            if(cnt >= q[i].k) lq[++ lt] = q[i];
            else q[i].k -= cnt, rq[++ rt] = q[i];
        }
    }

    for(int i = st; i <= ed; i ++ )
        if(q[i].op <= 0 && q[i].y <= mid)
            bit.add(q[i].x, -q[i].k);
    
    for(int i = 1; i <= lt; i ++ ) q[i + st - 1] = lq[i];
    for(int i = 1; i <= rt; i ++ ) q[i + lt + st - 1] = rq[i];

    solve(lval, mid, st, st + lt - 1);
    solve(mid + 1, rval, st + lt, ed);
}

int main()
{
    n = read(), m = read();

    for(int i = 1; i <= n; i ++ )
    {
        a[i] = read();
        q[++ tt] = {0, i, a[i], 1};
    }

    for(int i = 1; i <= m; i ++ )
    {
        char op[5];
        scanf("%s", op);
        
        if(op[0] == 'Q')
        {
            int l = read(), r = read(), k = read();
            q[++ tt] = {++ id, l, r, k};
        }
        else
        {
            int x = read(), y = read();
            q[++ tt] = {-1, x, a[x], -1};
            q[++ tt] = {0, x, y, 1};
            a[x] = y;
        }
    }

    solve(-INF, INF, 1, tt);

    for(int i = 1; i <= id; i ++ )
        printf("%d\n", ans[i]);
    
    return 0;
}

练习

P1527 [国家集训队]矩阵乘法

本题维护一个二维树状数组,然后和区间第 k 大没有一点不同。

struct tree_array
{
    int c[N][N];
    #define lowbit(x) x & -x

    inline void add(int x, int y, int val)
    {
        for(int i = x; i <= n; i += lowbit(i))
            for(int j = y; j <= n; j += lowbit(j))
                c[i][j] += val;
    }

    inline int query(int x, int y)
    {
        if(!x || !y) return 0;
        int res = 0;
        for(int i = x; i; i -= lowbit(i))
            for(int j = y; j; j -= lowbit(j))
                res += c[i][j];
        return res;
    }

    inline int ask(int x1, int y1, int x2, int y2)
    {
        return query(x2, y2) - query(x1 - 1, y2) - query(x2, y1 - 1) + query(x1 - 1, y1 - 1);
    }
} bit;

struct Q
{
    int op, x1, y1, x2, y2, k;
}q[M], lq[M], rq[M];
int ans[M];
int tt;

void solve(int lval, int rval, int st, int ed)
{
    if(st > ed) return;
    if(lval == rval)
    {
        for(int i = st; i <= ed; i ++ )
            if(q[i].op != 0)
                ans[q[i].op] = lval;
        return;
    }
    int mid = lval + rval >> 1;
    int lt = 0, rt = 0;
    for(int i = st; i <= ed; i ++ )
    {
        if(q[i].op == 0)
        {
            if(q[i].k <= mid) bit.add(q[i].x1, q[i].y1, 1), lq[++ lt] = q[i];
            else rq[++ rt] = q[i];
        }
        else
        {
            int cnt = bit.ask(q[i].x1, q[i].y1, q[i].x2, q[i].y2);
            if(cnt >= q[i].k) lq[++ lt] = q[i];
            else q[i].k -= cnt, rq[++ rt] = q[i];
        }
    }

    for(int i = st; i <= ed; i ++ )
        if(q[i].op == 0 && q[i].k <= mid)
            bit.add(q[i].x1, q[i].y1, -1);
    
    for(int i = 1; i <= lt; i ++ ) q[i + st - 1] = lq[i];
    for(int i = 1; i <= rt; i ++ ) q[st + lt + i - 1] = rq[i];

    solve(lval, mid, st, st + lt - 1);
    solve(mid + 1, rval, st + lt, ed);
}

P3527 [POI2011]MET-Meteors

很明显每个询问都能二分求解,时间长的肯定有更大概率收集全。

注意 \(l < r\) 的情况,这种情况可以将数组开成两倍,统计时加上长度即可。

统计每个国家收集多少的时候可以用图论的方式统计。

#include <bits/stdc++.h>
using namespace std;
#define int long long

const int N = 1e6 + 10;
int n, m, k;
int c[N], p[N], ans[N];

struct tree_array
{
    int c[N];
    #define lowbit(x) x & -x

    inline void add(int x, int val)
    {
        for(; x <= m * 2; x += lowbit(x)) c[x] += val;
    }

    inline int query(int x)
    {
        int res = 0;
        for(; x; x -= lowbit(x)) res += c[x];
        return res;
    }
} bit;

struct C
{
    int x, y, val;
}ch[N];

struct Q
{
    int c, k, h;
}q[N], lq[N], rq[N];

int e[N], ne[N], idx;

void add(int a, int b)
{
    e[++ idx] = b, ne[idx] = q[a].h, q[a].h = idx;
}

void solve(int lval, int rval, int st, int ed)
{
    if(st > ed) return;
    if(lval == rval)
    {
        for(int i = st; i <= ed; i ++ )
            ans[q[i].c] = lval;
        return;
    }

    int mid = lval + rval >> 1, lt = 0, rt = 0;
    for(int i = lval; i <= mid; i ++ ) 
    {
        bit.add(ch[i].x, ch[i].val), bit.add(ch[i].y + 1, -ch[i].val);
    }

    for(int i = st; i <= ed; i ++ )
    {
        int cnt = 0;
        for(int k = q[i].h; k && cnt <= q[i].k; k = ne[k])
        {
            int j = e[k];
            cnt += bit.query(j) + bit.query(j + m);
        }
        if(cnt >= q[i].k) lq[++ lt] = q[i];
        else q[i].k -= cnt, rq[++ rt] = q[i];
    }
    
    for(int i = lval; i <= mid; i ++ )
        bit.add(ch[i].x, -ch[i].val), bit.add(ch[i].y + 1, ch[i].val);

    for(int i = 1; i <= lt; i ++ ) q[i + st - 1] = lq[i];
    for(int i = 1; i <= rt; i ++ ) q[i + st + lt - 1] = rq[i];

    solve(lval, mid, st, st + lt - 1);
    solve(mid + 1, rval, st + lt, ed);
}

signed main()
{
    n = read(), m = read();

    for(int i = 1; i <= m; i ++ ) 
    {
        c[i] = read();
        add(c[i], i);
    }

    for(int i = 1; i <= n; i ++ )
    {
        q[i].k = read();
        q[i].c = i;
    }

    k = read();
    for(int i = 1; i <= k; i ++ )
    {
        int l = read(), r = read(), val = read();
        if(r < l) r += m;
        ch[i] = {l, r, val};
    }

    solve(1, k + 1, 1, n);

    for(int i = 1; i <= n; i ++ )
        if(ans[i] == k + 1) puts("NIE");
        else printf("%d\n", ans[i]);
    
    return 0;
}

P4602 [CTSC2018] 混合果汁

本题二分也很好想出来,唯一的难点在于怎么快速回答每个询问。

维护一个权值线段树,下标存储价格,内部存储物品的个数和总价格。

一开始先按美味值排序,维护一直到 mid 的前缀中的物品即可。


#include <bits/stdc++.h>
using namespace std;
#define int long long

const int N = 1e6 + 10;
int n, m, cur;
int ans[N];

struct segment
{
    int v, sum;
}t[N << 2];

inline void pushup(int p)
{
    t[p].v = t[p << 1].v + t[p << 1 | 1].v;
    t[p].sum = t[p << 1].sum + t[p << 1 | 1].sum;
}

void change(int p, int l, int r, int pos, int v)
{
    if(l == r) 
    {
        t[p].v += v;
        t[p].sum = l * t[p].v;
        return;
    }

    int mid = l + r >> 1;
    if(pos <= mid) change(p << 1, l, mid, pos, v);
    else change(p << 1 | 1, mid + 1, r, pos, v);
    pushup(p);
}

int query(int p, int l, int r, int v)
{
    if(!v) return 0;
    if(l == r) return l * v;
    int mid = l + r >> 1;
    if(t[p << 1].v >= v) return query(p << 1, l, mid, v);
    else return t[p << 1].sum + query(p << 1 | 1, mid + 1, r, v - t[p << 1].v);
}

int query1(int p, int l, int r, int pos)
{
    if(l == r) return t[p].sum;
    int mid = l + r >> 1;
    if(pos <= mid) return query1(p << 1, l, mid, pos);
    else return query1(p << 1 | 1, mid + 1, r, pos);
}

struct data
{
    int d, p, l;

    bool operator<(const data &D) const
    {
        return d > D.d;
    }
} a[N];

struct Q
{
    int id, g, l;
} q[N], lq[N], rq[N];

void solve(int lval, int rval, int st, int ed)
{
    if (st > ed || lval > rval)
        return;
    if (lval == rval)
    {
        for (int i = st; i <= ed; i ++)
            ans[q[i].id] = a[lval].d;
        return;
    }

    int mid = lval + rval >> 1;
    while(cur < mid)
        cur ++, change(1, 1, N - 1, a[cur].p, a[cur].l);
    while(cur > mid)
        change(1, 1, N - 1, a[cur].p, -a[cur].l), cur --;

    int lt = 0, rt = 0;
    for(int i = st; i <= ed; i ++ )
    {
        if(q[i].l > t[1].v) rq[++ rt] = q[i];
        else if(query(1, 1, N - 1, q[i].l) <= q[i].g) lq[++ lt] = q[i];
        else rq[++ rt] = q[i];
    }

    for(int i = 1; i <= lt; i ++ ) q[i + st - 1] = lq[i];
    for(int i = 1; i <= rt; i ++ ) q[i + st + lt - 1] = rq[i];

    solve(lval, mid, st, st + lt - 1);
    solve(mid + 1, rval, st + lt, ed);
}

signed main()
{
    n = read(), m = read();

    for (int i = 1; i <= n; i++)
    {
        a[i].d = read(), a[i].p = read(), a[i].l = read();
    }

    a[++ n] = {-1, 0, 0x3f3f3f3f};
    sort(a + 1, a + n + 1);

    for (int i = 1; i <= m; i++)
    {
        int D = read(), L = read();
        q[i] = {i, D, L};
    }

    solve(1, n, 1, m);

    for (int i = 1; i <= m; i++)
        printf("%d\n", ans[i]);

    return 0;
}
posted @ 2023-07-10 18:10  crimson000  阅读(21)  评论(1编辑  收藏  举报