洛谷题单指南-线段树的进阶用法-P3810 【模板】三维偏序(陌上花开)

原题链接:https://www.luogu.com.cn/problem/P3810

题意解读:题意很明显,有n组三元组,对于f(i),表示j!=i的情况下,所有的aj<=ai,bj<=bi,cj<=ci,这样的j的数量。求f(i)=0,1,2...n-1的i的个数。

解题思路:

先将三元组按a排序,a相同的按b排序,b相同的按c排序,排序都是从小到大。

设三元组结构体数组为

struct Record
{
    int a, b, c;
} rs[N];

依次遍历每一个记录,显然对于记录rs[i]前面出现的记录rs[j]的a、b、c都不大于它。

那么,可以针对b值建立k个权值线段树,b的范围是1~k,每个权值线段树保存c值的数量。

权值线段树节点定义为

struct Node
{
    int L, R; //左右子节点的下标
    int cnt; //节点所表示值域区间的元素个数
} tr[M * 4 * 20];

每一颗线段树的根节点为:

int root[M];

依次遍历rs记录,对于记录rs[i],每次先查询在root[1]~root[rs[i].b]所有线段树中rs[i].c的数量,即可得到f(i)的值,将答案ans[f(i)]++;

然后将线段树root[rs[i].b]~root[k]中rs[i].c的数量加1,这一步时间复杂度较高,考虑优化。

对于这种前缀和查询、单调更新的操作,很容易想到树状数组优化,因此最终结构是用树状数组来维护所有权值线段树的根节点root[M]。

还有一点需要注意:

如果有多个连续rs[i]~rs[j]记录的a、b、c相同,那么这些f(i)值都一样,查询之前应该先将j-i个rs[]记录加入根为root[b]~root[k]中,也就是线段树中c值的个数增加j-i。

具体逻辑参考代码。

100分代码:

#include <bits/stdc++.h>
using namespace std;

const int N = 100005, M = 200005;

struct Record
{
    int a, b, c;
} rs[N];

struct Node
{
    int L, R; //左右子节点的下标
    int cnt; //节点所表示值域区间的元素个数
} tr[M * 4 * 20];

int root[M], idx;
int f[N], ans[N];
int n, k;

bool cmp(Record x, Record y)
{
    if(x.a != y.a) return x.a < y.a;
    else if(x.b != y.b) return x.b < y.b;
    else return x.c < y.c;
}

int lowbit(int x)
{
    return x & -x;
}

void pushup(int u)
{
    tr[u].cnt = tr[tr[u].L].cnt + tr[tr[u].R].cnt;
}

//在根为u的权值线段树中,查询值为1~v的元素个数
int query(int u, int l, int r, int v)
{
    if(l >= 1 && r <= v) return tr[u].cnt;
    else if(l > v || r < 1) return 0;
    else 
    {
        int mid = l + r >> 1;
        return query(tr[u].L, l, mid, v) + query(tr[u].R, mid + 1, r, v);
    }
}

//在根为u的权值线段树中,将值v的个数加add,不使用复制而使用动态开点
int update(int u, int l, int r, int v, int add)
{
    if(!u) u = ++idx;
    if(l == r) 
    {
        tr[u].cnt += add;
        return u;
    }
    int mid = l + r >> 1;
    if(v <= mid) tr[u].L = update(tr[u].L, l, mid, v, add);
    else tr[u].R = update(tr[u].R, mid + 1, r, v, add);
    pushup(u);
    return u;
}

//利用树状数组,查询所有根节点是root[1]~root[r1]的权值线段树中值为1~r2的元素个数
int find(int r1, int r2)
{
    int sum = 0;
    for(int i = r1; i; i -= lowbit(i)) sum += query(root[i], 1, k, r2);
    return sum;
}

//利用树状数组,在根节点是root[r1]~root[k]...的线段树中,将值r2的个数加add
void add(int r1, int r2, int add)
{
    for(int i = r1; i <= k; i += lowbit(i)) root[i] = update(root[i], 1, k, r2, add);
}

int main()  
{
    cin >> n >> k;
    for(int i = 1; i <= n; i++) cin >> rs[i].a >> rs[i].b >> rs[i].c;
    sort(rs + 1, rs + n + 1, cmp);
    for(int i = 1; i <= n; i++)
    {
        int j = i; 
        while(rs[j].a == rs[j+1].a && rs[j].b == rs[j+1].b && rs[j].c == rs[j+1].c) j++;
        if(j > i) //存在多个连续rs[]相同,从rs[i]~rs[j]都相同
        {
            add(rs[i].b, rs[i].c, j - i); //所有相同rs[]的f()值要加上除自身自外的记录个数
            int tmp = find(rs[i].b, rs[i].c); //查出相同rs[]的f()值
            for(int st = i; st <= j; st++) f[st] = tmp, ans[f[st]]++; //给f(i)~f(j)赋相同的值
            i = j;
        }
        else f[i] = find(rs[i].b, rs[i].c), ans[f[i]]++;
        add(rs[i].b, rs[i].c, 1); 
    }
    for(int i = 0; i < n; i++)
    {
        cout << ans[i] << endl;
    }
    return 0;
}

 

posted @ 2025-01-06 11:08  五月江城  阅读(4)  评论(0编辑  收藏  举报