洛谷题单指南-线段树的进阶用法-P3810 【模板】三维偏序(陌上花开)
原题链接:https://www.luogu.com.cn/problem/P3810
题意解读:题意很明显,有n组三元组,对于f(i),表示j!=i的情况下,所有的aj<=ai,bj<=bi,cj<=ci,这样的j的数量。求f(i)=0,1,2...n-1的i的个数。
解题思路:
先将三元组按a排序,a相同的按b排序,b相同的按c排序,排序都是从小到大。
设三元组结构体数组为
struct Record
{
int a, b, c;
} rs[N];
依次遍历每一个记录,显然对于记录rs[i]前面出现的记录rs[j]的a、b、c都不大于它。
那么,可以针对b值建立k个权值线段树,b的范围是1~k,每个权值线段树保存c值的数量。
权值线段树节点定义为
struct Node
{
int L, R; //左右子节点的下标
int cnt; //节点所表示值域区间的元素个数
} tr[M * 4 * 20];
每一颗线段树的根节点为:
int root[M];
依次遍历rs记录,对于记录rs[i],每次先查询在root[1]~root[rs[i].b]所有线段树中rs[i].c的数量,即可得到f(i)的值,将答案ans[f(i)]++;
然后将线段树root[rs[i].b]~root[k]中rs[i].c的数量加1,这一步时间复杂度较高,考虑优化。
对于这种前缀和查询、单调更新的操作,很容易想到树状数组优化,因此最终结构是用树状数组来维护所有权值线段树的根节点root[M]。
还有一点需要注意:
如果有多个连续rs[i]~rs[j]记录的a、b、c相同,那么这些f(i)值都一样,查询之前应该先将j-i个rs[]记录加入根为root[b]~root[k]中,也就是线段树中c值的个数增加j-i。
具体逻辑参考代码。
100分代码:
#include <bits/stdc++.h>
using namespace std;
const int N = 100005, M = 200005;
struct Record
{
int a, b, c;
} rs[N];
struct Node
{
int L, R; //左右子节点的下标
int cnt; //节点所表示值域区间的元素个数
} tr[M * 4 * 20];
int root[M], idx;
int f[N], ans[N];
int n, k;
bool cmp(Record x, Record y)
{
if(x.a != y.a) return x.a < y.a;
else if(x.b != y.b) return x.b < y.b;
else return x.c < y.c;
}
int lowbit(int x)
{
return x & -x;
}
void pushup(int u)
{
tr[u].cnt = tr[tr[u].L].cnt + tr[tr[u].R].cnt;
}
//在根为u的权值线段树中,查询值为1~v的元素个数
int query(int u, int l, int r, int v)
{
if(l >= 1 && r <= v) return tr[u].cnt;
else if(l > v || r < 1) return 0;
else
{
int mid = l + r >> 1;
return query(tr[u].L, l, mid, v) + query(tr[u].R, mid + 1, r, v);
}
}
//在根为u的权值线段树中,将值v的个数加add,不使用复制而使用动态开点
int update(int u, int l, int r, int v, int add)
{
if(!u) u = ++idx;
if(l == r)
{
tr[u].cnt += add;
return u;
}
int mid = l + r >> 1;
if(v <= mid) tr[u].L = update(tr[u].L, l, mid, v, add);
else tr[u].R = update(tr[u].R, mid + 1, r, v, add);
pushup(u);
return u;
}
//利用树状数组,查询所有根节点是root[1]~root[r1]的权值线段树中值为1~r2的元素个数
int find(int r1, int r2)
{
int sum = 0;
for(int i = r1; i; i -= lowbit(i)) sum += query(root[i], 1, k, r2);
return sum;
}
//利用树状数组,在根节点是root[r1]~root[k]...的线段树中,将值r2的个数加add
void add(int r1, int r2, int add)
{
for(int i = r1; i <= k; i += lowbit(i)) root[i] = update(root[i], 1, k, r2, add);
}
int main()
{
cin >> n >> k;
for(int i = 1; i <= n; i++) cin >> rs[i].a >> rs[i].b >> rs[i].c;
sort(rs + 1, rs + n + 1, cmp);
for(int i = 1; i <= n; i++)
{
int j = i;
while(rs[j].a == rs[j+1].a && rs[j].b == rs[j+1].b && rs[j].c == rs[j+1].c) j++;
if(j > i) //存在多个连续rs[]相同,从rs[i]~rs[j]都相同
{
add(rs[i].b, rs[i].c, j - i); //所有相同rs[]的f()值要加上除自身自外的记录个数
int tmp = find(rs[i].b, rs[i].c); //查出相同rs[]的f()值
for(int st = i; st <= j; st++) f[st] = tmp, ans[f[st]]++; //给f(i)~f(j)赋相同的值
i = j;
}
else f[i] = find(rs[i].b, rs[i].c), ans[f[i]]++;
add(rs[i].b, rs[i].c, 1);
}
for(int i = 0; i < n; i++)
{
cout << ans[i] << endl;
}
return 0;
}