【知识】树套树

树套树

顾名思义,就是一棵树套着另一棵树:外层树的每个节点上再维护一棵内层树。

例如:线段树套平衡树,线段树中的每个节点的区间用平衡树维护。

常用:

  • 外层:线段树,树状数组
  • 内层:平衡树,线段树。(一般可以用 STL,例如 set / multiset)

例题:

  • AcWing 2488

    没啥好说的,线段树套 set(区间内可能有重复元素,所以代码里用的是 multiset)。

    #include <bits/stdc++.h>
    using namespace std;
    
    const int N = 50005, M = N << 2;
    const int INF = 0x3f3f3f3f;
    int n, m;
    // One segment-tree node: it covers positions [l, r] of w and keeps the
    // values stored there (plus the sentinels added by build) in a multiset.
    struct Tree{
        int l;            // left end of the covered index range
        int r;            // right end of the covered index range
        multiset<int> s;  // ordered values of the range; duplicates are kept
    };
    // Node pool: build() puts the root at index 1 and the children of u at
    // 2u and 2u+1.
    Tree tr[M];
    int w[N];
    
    void build(int u,int l,int r){
        tr[u] = {l, r};
        tr[u].s.insert(-INF), tr[u].s.insert(INF);
        for (int i = l; i <= r;i++)
            tr[u].s.insert(w[i]);
        int mid = l + r >> 1;
        if(l==r)
            return;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
    }
    
    void change(int u,int p,int x){
        tr[u].s.erase(tr[u].s.find(w[p]));
        tr[u].s.insert(x);
        if(tr[u].l==tr[u].r)
            return;
        int mid = tr[u].l + tr[u].r >> 1;
        if(p<=mid)
            change(u << 1, p, x);
        else
            change(u << 1 | 1, p, x);
    }
    
    int query(int u,int a,int b,int x){
        if(tr[u].l>=a&&tr[u].r<=b){
            auto it = tr[u].s.lower_bound(x);
            --it;
            return *it;
        }
        int mid = tr[u].l + tr[u].r >> 1;
        int res = -INF;
        if(a<=mid)
            res = max(res, query(u << 1, a, b, x));
        if(b>mid)
            res = max(res, query(u << 1 | 1, a, b, x));
        return res;
    }
    // Reads n, m and the initial array w[1..n], builds the tree, then serves
    // m commands:
    //   1 a x      -> set w[a] = x
    //   <other>    -> read a b x and print the largest value < x in w[a..b]
    int main(){
        cin >> n >> m;
        for (int i = 1; i <= n;i++)
            cin >> w[i];
        build(1, 1, n);
        while(m--){
            int op, a, b, x;
            cin >> op;
            // Fixed: this used to compare op with x, which is still
            // uninitialized at this point; opcode 1 is the modification.
            if(op==1){
                cin >> a >> x;
                change(1, a, x);
                w[a] = x;  // change() looks up the old w[a], so store x afterwards
            }
            else{
                cin >> a >> b >> x;
                // '\n' instead of endl: same output without a flush per line.
                cout << query(1, a, b, x) << '\n';
            }
        }
        return 0;
    }
    
  • P3380 【模板】树套树

    #include <bits/stdc++.h>
    using namespace std;
    
    const int N = 2000005, INF = 2147483647;
    int n, m;
    // A splay-tree node. Every splay tree in the program takes its nodes from
    // the shared pool tr[]; index 0 is used as the "no node" value.
    struct Node{
        int s[2];  // s[0] = left child, s[1] = right child (0 = none)
        int p;     // parent index
        int v;     // key stored in the node
        int sz;    // number of nodes in the subtree rooted here
        // Turns this node into a one-node subtree with key _v below parent _p.
        // The child links are left untouched.
        void init(int _v,int _p){
            v = _v;
            p = _p;
            sz = 1;
        }
    };
    Node tr[N];
    int L[N], R[N], T[N], idx;
    int w[N];
    
    // Recomputes the subtree size of x from the sizes of its two children.
    void pushup(int x){
        const int lc = tr[x].s[0];
        const int rc = tr[x].s[1];
        tr[x].sz = 1 + tr[lc].sz + tr[rc].sz;
    }
    
    // Rotates x one level up so that it takes the place of its parent y;
    // y becomes a child of x on the opposite side. Sizes of y and x are
    // recomputed afterwards.
    void rotate(int x){
        int y = tr[x].p, z = tr[y].p;
        // k is 1 when x is the right child of y, otherwise 0.
        int k = tr[y].s[1] == x;
        // z adopts x in the slot that y used to occupy.
        tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
        // The subtree of x on side k^1 moves over to side k of y.
        tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
        // y becomes the child of x on side k^1.
        tr[x].s[k ^ 1] = y, tr[y].p = x;
        // y is now below x, so its size has to be refreshed first.
        pushup(y), pushup(x);
    }
    
    // Rotates x upward until its parent is k. With k == 0 the node x ends up
    // on top of its tree and `root` is set to x.
    void splay(int &root,int x,int k){
        for(int y = tr[x].p; y != k; y = tr[x].p){
            const int z = tr[y].p;
            if(z != k){
                const bool x_is_right = (tr[y].s[1] == x);
                const bool y_is_right = (tr[z].s[1] == y);
                // Different sides (zig-zag): rotate x first.
                // Same side (zig-zig): rotate the parent first.
                rotate(x_is_right != y_is_right ? x : y);
            }
            rotate(x);
        }
        if(k == 0)
            root = x;
    }
    
    // Inserts key v into the tree rooted at `root` as a new node taken from
    // the pool (++idx), then splays the new node to the top. A key equal to
    // an existing one descends to the left.
    void insert(int &root,int v){
        int parent = 0;
        for(int cur = root; cur != 0; cur = tr[cur].s[v > tr[cur].v])
            parent = cur;
        const int fresh = ++idx;
        if(parent != 0)
            tr[parent].s[v > tr[parent].v] = fresh;
        tr[fresh].init(v, parent);
        splay(root, fresh, 0);
    }
    
    int get_k(int root,int v){
        int u = root, res = 0;
        while(u){
            if(tr[u].v<v)
                res += tr[tr[u].s[0]].sz + 1, u = tr[u].s[1];
            else
                u = tr[u].s[0];
        }
        return res;
    }
    
    // Replaces one node holding key x by a new node holding key y in the tree
    // rooted at `root`.
    // A missing key is not handled: x must be present (change() always passes
    // the current w[p]). The code also reads both neighbours of the found
    // node, which the -INF / INF sentinels inserted by build() provide.
    void update(int &root,int x,int y){
        // Locate a node with key x and bring it to the top.
        int u = root;
        while(u && tr[u].v != x)
            u = tr[u].s[tr[u].v < x];
        splay(root, u, 0);
        // l = in-order predecessor, r = in-order successor of u.
        int l = tr[u].s[0], r = tr[u].s[1];
        while(tr[l].s[1])
            l = tr[l].s[1];
        while(tr[r].s[0])
            r = tr[r].s[0];
        // With l on top and r directly below it, the left subtree of r holds
        // exactly the node between them, i.e. u; cut it off.
        splay(root, l, 0), splay(root, r, l);
        tr[r].s[0] = 0;
        // Fixed order: r is the child of l, so r must be refreshed first or l
        // would be recomputed from r's stale size.
        pushup(r), pushup(l);
        insert(root, y);
    }
    
    // Builds segment-tree node u for positions [l, r]: stores the range in
    // L[u] / R[u] and fills the splay tree T[u] with the sentinels INF and
    // -INF followed by w[l..r]; recurses unless the range is one position.
    void build(int u,int l,int r){
        L[u] = l;
        R[u] = r;
        int &root = T[u];
        insert(root, INF);
        insert(root, -INF);
        for(int pos = l; pos <= r; ++pos)
            insert(root, w[pos]);
        if(l != r){
            const int mid = (l + r) >> 1;
            build(u << 1, l, mid);
            build(u << 1 | 1, mid + 1, r);
        }
    }
    
    // Returns how many values at positions [a, b] are strictly smaller than x.
    int query(int u,int a,int b,int x){
        if(a <= L[u] && R[u] <= b)
            return get_k(T[u], x) - 1;  // -1 drops the -INF sentinel counted by get_k
        const int mid = (L[u] + R[u]) >> 1;
        const int left = (a <= mid) ? query(u << 1, a, b, x) : 0;
        const int right = (b > mid) ? query(u << 1 | 1, a, b, x) : 0;
        return left + right;
    }
    
    // Replaces the current value w[p] by x in the splay tree of every node on
    // the path from u to the leaf of position p. w[p] itself is not written
    // here; main stores x into it after the call.
    void change(int u,int p,int x){
        while(true){
            update(T[u], w[p], x);
            if(L[u] == R[u])
                return;
            const int mid = (L[u] + R[u]) >> 1;
            u = (u << 1) | (p <= mid ? 0 : 1);
        }
    }
    
    // Returns the largest key strictly smaller than v in the tree rooted at
    // `root`, or -INF when there is none.
    int get_pre(int root,int v){
        int best = -INF;
        for(int cur = root; cur != 0; ){
            const int key = tr[cur].v;
            if(key < v){
                if(key > best)
                    best = key;
                cur = tr[cur].s[1];
            }
            else
                cur = tr[cur].s[0];
        }
        return best;
    }
    
    // Returns the smallest key strictly greater than v in the tree rooted at
    // `root`, or INF when there is none.
    int get_suc(int root,int v){
        int best = INF;
        for(int cur = root; cur != 0; ){
            const int key = tr[cur].v;
            if(key > v){
                if(key < best)
                    best = key;
                cur = tr[cur].s[0];
            }
            else
                cur = tr[cur].s[1];
        }
        return best;
    }
    // Returns the largest value strictly smaller than x among positions
    // [a, b], or -INF when there is none.
    int query_pre(int u,int a,int b,int x){
        if(a <= L[u] && R[u] <= b)
            return get_pre(T[u], x);
        const int mid = (L[u] + R[u]) >> 1;
        int best = -INF;
        if(a <= mid){
            const int cand = query_pre(u << 1, a, b, x);
            if(cand > best)
                best = cand;
        }
        if(b > mid){
            const int cand = query_pre(u << 1 | 1, a, b, x);
            if(cand > best)
                best = cand;
        }
        return best;
    }
    
    // Returns the smallest value strictly greater than x among positions
    // [a, b], or INF when there is none.
    int query_suc(int u,int a,int b,int x){
        if(a <= L[u] && R[u] <= b)
            return get_suc(T[u], x);
        const int mid = (L[u] + R[u]) >> 1;
        int best = INF;
        if(a <= mid){
            const int cand = query_suc(u << 1, a, b, x);
            if(cand < best)
                best = cand;
        }
        if(b > mid){
            const int cand = query_suc(u << 1 | 1, a, b, x);
            if(cand < best)
                best = cand;
        }
        return best;
    }
    
    // Reads n, m and w[1..n], builds the tree and serves m commands:
    //   1 a b x -> rank of x in w[a..b] (count of smaller values + 1)
    //   2 a b x -> value of rank x in w[a..b], found by binary search on the value
    //   3 a x   -> set w[a] = x
    //   4 a b x -> predecessor of x in w[a..b]
    //   other   -> successor of x in w[a..b]
    int main(){
        cin >> n >> m;
        for(int i = 1; i <= n; ++i)
            cin >> w[i];
        build(1, 1, n);
        while(m--){
            int op, a, b, x;
            cin >> op;
            switch(op){
            case 1:
                cin >> a >> b >> x;
                cout << query(1, a, b, x) + 1 << endl;
                break;
            case 2:{
                cin >> a >> b >> x;
                // Largest candidate in [0, 1e8] whose rank does not exceed x.
                int lo = 0, hi = 1e8;
                while(lo < hi){
                    const int mid = (lo + hi + 1) >> 1;
                    if(query(1, a, b, mid) + 1 <= x)
                        lo = mid;
                    else
                        hi = mid - 1;
                }
                cout << hi << endl;
                break;
            }
            case 3:
                cin >> a >> x;
                change(1, a, x);
                w[a] = x;  // change() looks up the old w[a], so store x afterwards
                break;
            case 4:
                cin >> a >> b >> x;
                cout << query_pre(1, a, b, x) << endl;
                break;
            default:
                cin >> a >> b >> x;
                cout << query_suc(1, a, b, x) << endl;
                break;
            }
        }
        return 0;
    }
    
  • P3332 [ZJOI2013] K大数查询

    考虑值域线段树套线段树。

    Tips:标记永久化,动态开点线段树。

    #include <iostream>
    #include <cstring>
    #include <cstdio>
    #include <algorithm>
    #include <vector>
    
    using namespace std;
    
    typedef long long LL;
    
    const int N = 50010, P = N * 17 * 17, M = N * 4;
    
    int n, m;
    // Node of the inner, dynamically allocated segment trees over positions.
    struct Tree
    {
        int l;    // index of the left child in tr[], 0 while it does not exist
        int r;    // index of the right child in tr[], 0 while it does not exist
        LL sum;   // elements added to this node's range by updates that reached this node
        LL add;   // number of updates that covered this node's whole range (never pushed down)
    };
    Tree tr[P];
    int L[M], R[M], T[M], idx;
    // One input command, stored so that all inserted values can be collected
    // before any command is processed. In main, op 1 inserts value c into
    // positions [a, b]; any other op asks for the c-th largest value in [a, b].
    struct Query
    {
        int op;
        int a;
        int b;
        int c;
    };
    Query q[N];
    vector<int> nums;
    
    // Maps value x to its index in the sorted, de-duplicated vector nums
    // (the position of the first element that is not smaller than x).
    int get(int x)
    {
        auto pos = lower_bound(nums.begin(), nums.end(), x);
        return pos - nums.begin();
    }
    
    // Builds outer node u for the value-index range [l, r] and gives it a
    // fresh inner-tree root T[u] from the shared pool; recurses into both
    // halves unless l == r.
    void build(int u, int l, int r)
    {
        L[u] = l;
        R[u] = r;
        T[u] = ++ idx;
        if (l != r)
        {
            const int mid = (l + r) >> 1;
            build(u << 1, l, mid);
            build(u << 1 | 1, mid + 1, r);
        }
    }
    
    // Number of integers shared by the closed intervals [a, b] and [c, d].
    // The result is zero or negative when the intervals do not overlap.
    int intersection(int a, int b, int c, int d)
    {
        const int lo = (a > c) ? a : c;  // later of the two starts
        const int hi = (b < d) ? b : d;  // earlier of the two ends
        return hi - lo + 1;
    }
    
    void update(int u, int l, int r, int pl, int pr)
    {
        tr[u].sum += intersection(l, r, pl, pr);
        if (l >= pl && r <= pr)
        {
            tr[u].add ++ ;
            return;
        }
        int mid = l + r >> 1;
        if (pl <= mid)
        {
            if (!tr[u].l) tr[u].l = ++ idx;
            update(tr[u].l, l, mid, pl, pr);
        }
        if (pr > mid)
        {
            if (!tr[u].r) tr[u].r = ++ idx;
            update(tr[u].r, mid + 1, r, pl, pr);
        }
    }
    
    // Records "value index c was added to positions [a, b]" in the inner tree
    // of every outer node on the path from u to the leaf of c.
    void change(int u, int a, int b, int c)
    {
        for (;;)
        {
            update(T[u], 1, n, a, b);
            if (L[u] == R[u]) break;
            const int mid = (L[u] + R[u]) >> 1;
            u = (c <= mid) ? (u << 1) : (u << 1 | 1);
        }
    }
    
    // Returns the number of elements stored at positions [pl, pr] in the inner
    // tree whose current node u covers [l, r]. `add` is the total of the add
    // tags of the ancestors of u; tags are never pushed down, so they are
    // applied on the way down instead.
    LL get_sum(int u, int l, int r, int pl, int pr, int add)
    {
        // tr[u].sum already contains u's own tag, so only ancestors' tags are added.
        if (l >= pl && r <= pr) return tr[u].sum + (r - l + 1LL) * add;
        int mid = l + r >> 1;
        LL res = 0;
        add += tr[u].add;
        if (pl <= mid)
        {
            if (tr[u].l) res += get_sum(tr[u].l, l, mid, pl, pr, add);
            // A missing child only holds what the tags above it contribute.
            // Widened to LL so the product cannot overflow int.
            else res += static_cast<LL>(intersection(l, mid, pl, pr)) * add;
        }
        if (pr > mid)
        {
            if (tr[u].r) res += get_sum(tr[u].r, mid + 1, r, pl, pr, add);
            else res += static_cast<LL>(intersection(mid + 1, r, pl, pr)) * add;
        }
        return res;
    }
    
    // Returns the value index of the c-th largest element stored at positions
    // [a, b], searching the outer node u.
    int query(int u, int a, int b, int c)
    {
        if (L[u] == R[u]) return R[u];
        // k = number of elements in [a, b] whose value lies in the upper half.
        LL k = get_sum(T[u << 1 | 1], 1, n, a, b, 0);
        if (k >= c) return query(u << 1 | 1, a, b, c);
        // Otherwise skip those k elements; here k < c, so c - k fits in int.
        return query(u << 1, a, b, static_cast<int>(c - k));
    }
    
    // Reads all commands first so the inserted values can be compressed, then
    // replays them: op 1 inserts value c into positions [a, b]; any other op
    // prints the c-th largest value stored in positions [a, b].
    // NOTE(review): c is read with %d for every command. Confirm against the
    // judge's constraints that query ranks always fit in a 32-bit int.
    int main()
    {
        scanf("%d%d", &n, &m);
        for (int i = 0; i < m; i ++ )
        {
            scanf("%d%d%d%d", &q[i].op, &q[i].a, &q[i].b, &q[i].c);
            if (q[i].op == 1) nums.push_back(q[i].c);
        }
        sort(nums.begin(), nums.end());
        nums.erase(unique(nums.begin(), nums.end()), nums.end());
        // No insertion at all: there is nothing to build or to answer, and
        // build(1, 0, -1) would never reach l == r and recurse without bound.
        if (nums.empty()) return 0;
        build(1, 0, static_cast<int>(nums.size()) - 1);
        for (int i = 0; i < m; i ++ )
        {
            int op = q[i].op, a = q[i].a, b = q[i].b, c = q[i].c;
            if (op == 1) change(1, a, b, get(c));
            else printf("%d\n", nums[query(1, a, b, c)]);
        }
        return 0;
    }
    
posted @   Star_F  阅读(28)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· 没有源码,如何修改代码逻辑?
· NetPad:一个.NET开源、跨平台的C#编辑器
· PowerShell开发游戏 · 打蜜蜂
· 凌晨三点救火实录:Java内存泄漏的七个神坑,你至少踩过三个!
点击右上角即可分享
微信分享提示