[BZOJ 3110] [Zjoi2013] K大数查询 【树套树】

题目链接: BZOJ - 3110

 

题目分析

这道题是一道树套树的典型题目,我们使用线段树套线段树,一层是区间线段树,一层是权值线段树。一般的思路是外层用区间线段树,内层用权值线段树,但是这样貌似会很难写。多数题解都使用了外层权值线段树,内层区间线段树,于是我就这样写了。每次插入会在 logn 棵线段树中一共建 log^2(n) 个结点,所以空间应该开到 O(nlog^2(n)) 。由于这道题查询的是区间第 k 大,所以我们存在线段树中的数值是输入数值的相反数(再加上 n 使其为正数),这样查第 k 小就可以了。在查询区间第 k 大值的时候,我们用类似二分的方法,一层一层地逼近答案。

写代码的时候出现的错误:在每一棵区间线段树中修改数值的时候,应该调用的是像 Insert(Lc[x], 1, n, l, r) 这样子,但我经常写成 Insert(x << 1, s, t, l, r) 之类的。注意!

 

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <algorithm>

using namespace std;

const int MaxN = 100000 + 5, MaxM = 100000 * 16 * 16 + 5;

int n, m, f, a, b, c, Index, Ans;
int Root[MaxN * 4], Lc[MaxM], Rc[MaxM], Sum[MaxM], Lazy[MaxM];

inline int gmin(int a, int b) {
	return a < b ? a : b;
}
inline int gmax(int a, int b) {
	return a > b ? a : b;
}

int Get(int x, int s, int t, int l, int r) {
	if (l <= s && r >= t) return Sum[x];
	int p = 0, q = 0, m = (s + t) >> 1;
	if (l <= m) p = Get(Lc[x], s, m, l, r);
	if (r >= m + 1) q = Get(Rc[x], m + 1, t, l, r);
	return (p + q + Lazy[x] * (gmin(t, r) - gmax(s, l) + 1));
}

int GetKth(int l, int r, int k) {
	int s = 1, t = n * 2, m, x = 1, Temp;
	while (s != t) {
		m = (s + t) >> 1;
		if ((Temp = Get(Root[x << 1], 1, n, l, r)) >= k) {
			t = m; x = x << 1;
		}
		else {
			s = m + 1; x = x << 1 | 1; k -= Temp;
		}
	}
	return s;
}

void Insert(int &x, int s, int t, int l, int r) {
	if (x == 0) x = ++Index;
	if (l <= s && r >= t) {
		Sum[x] += t - s + 1;
		++Lazy[x];
		return;
	}
	int m = (s + t) >> 1;
	if (l <= m) Insert(Lc[x], s, m, l, r);
	if (r >= m + 1) Insert(Rc[x], m + 1, t, l, r);
	Sum[x] = Sum[Lc[x]] + Sum[Rc[x]] + Lazy[x] * (t - s + 1);
}

void Add(int l, int r, int Num) {
	int s = 1, t = n * 2, m, x = 1;
	while (s != t) {
		Insert(Root[x], 1, n, l, r);
		m = (s + t) >> 1;	
		if (Num <= m) {
			t = m;
			x = x << 1;
		}
		else {
			s = m + 1;
			x = x << 1 | 1;
		}
	}
	Insert(Root[x], 1, n, l, r);
}

int main() 
{
	scanf("%d%d", &n, &m);
	Index = 0;
	for (int i = 1; i <= m; ++i) {
		scanf("%d%d%d%d", &f, &a, &b, &c);
		if (f == 1) {
			c = -c + n + 1;
			Add(a, b, c);
		}
		else {
			Ans = GetKth(a, b, c);
			Ans = -Ans + n + 1;
			printf("%d\n", Ans);
		}
	}	
	return 0;
}

  

posted @ 2014-12-18 21:29  JoeFan  阅读(1193)  评论(0编辑  收藏  举报