【BZOJ4919】[Lydsy六月月赛]大根堆 线段树合并

【BZOJ4919】[Lydsy六月月赛]大根堆

Description

给定一棵n个节点的有根树,编号依次为1到n,其中1号点为根节点。每个点有一个权值v_i。
你需要将这棵树转化成一个大根堆。确切地说,你需要选择尽可能多的节点,满足大根堆的性质:对于任意两个点i,j,如果i在树上是j的祖先,那么v_i>v_j。
请计算可选的最多的点数,注意这些点不必形成这棵树的一个连通子树。

Input

第一行包含一个正整数n(1<=n<=200000),表示节点的个数。
接下来n行,每行两个整数v_i,p_i(0<=v_i<=10^9,1<=p_i<i,p_1=0),表示每个节点的权值与父亲。

Output

输出一行一个正整数,即最多的点数。

Sample Input

6
3 0
1 1
2 1
3 1
4 1
5 1

Sample Output

5

题解:考虑用f[i][j]表示在i节点的子树中,最大值<=j,最多能选择多少点。如何转移呢?父亲节点的f数组可以看成儿子节点的f数组对应位置相加。然后再用 当前点权值-1处的f值 +1 来更新当前点权值后面的所有f值。

为此,我们可以考虑用线段树+标记永久化维护,我们要实现维护区间最大值。然后转移的时候可以直接用线段树合并搞定。细节还是比较多的。

upd:多说一点吧。有标记的线段树进行合并时也是比较恶心的。对于区间max标记,我们要进行标记永久化。这样的话每个节点的最大值标记对整个区间就都适用。在合并a和b的某个儿子时,如果这个儿子a有b没有,那么我们可以直接让a的标记对b的儿子生效,反之亦然。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn=200010;
int n,m,cnt,tot,ans;
int to[maxn],next[maxn],head[maxn],p[maxn],val[maxn],v[maxn],rt[maxn];
inline void add(int a,int b)
{
	to[cnt]=b,next[cnt]=head[a],head[a]=cnt++;
}
bool cmp(const int &a,const int &b)
{
	return val[a]<val[b];
}
struct sag
{
	int ls,rs,tag,sum;
}s[maxn<<6];
inline void pushdown(int x)
{
	if(s[x].ls)	s[s[x].ls].sum+=s[x].sum,s[s[x].ls].tag=max(s[s[x].ls].tag+s[x].sum,s[x].tag);
	if(s[x].rs)	s[s[x].rs].sum+=s[x].sum,s[s[x].rs].tag=max(s[s[x].rs].tag+s[x].sum,s[x].tag);
	s[x].sum=0;
}
int merge(int a,int b)
{
	if(!a||!b)	return a^b;
	pushdown(a),pushdown(b);
	if(!s[a].ls)	s[a].ls=s[b].ls,s[s[a].ls].tag+=s[a].tag,s[s[a].ls].sum+=s[a].tag+s[a].sum;
	else	if(!s[b].ls)	s[s[a].ls].tag+=s[b].tag,s[s[a].ls].sum+=s[b].tag+s[b].sum;
	else	s[a].ls=merge(s[a].ls,s[b].ls);
	if(!s[a].rs)	s[a].rs=s[b].rs,s[s[a].rs].tag+=s[a].tag,s[s[a].rs].sum+=s[a].tag+s[a].sum;
	else	if(!s[b].rs)	s[s[a].rs].tag+=s[b].tag,s[s[a].rs].sum+=s[b].tag+s[b].sum;
	else	s[a].rs=merge(s[a].rs,s[b].rs);
	s[a].tag+=s[b].tag;
	return a;
}
void updata(int l,int r,int &x,int a,int b,int c)
{
	if(!x)	x=++tot;
	if(a<=l&&r<=b)
	{
		s[x].tag=max(s[x].tag,c);
		return ;
	}
	pushdown(x);
	int mid=(l+r)>>1;
	if(a<=mid)	updata(l,mid,s[x].ls,a,b,c);
	if(b>mid)	updata(mid+1,r,s[x].rs,a,b,c);
}
int query(int l,int r,int x,int a)
{
	if(!x||!a)	return 0;
	if(l==r)	return s[x].tag;
	pushdown(x);
	int mid=(l+r)>>1;
	if(a<=mid)	return max(s[x].tag,query(l,mid,s[x].ls,a));
	return max(s[x].tag,query(mid+1,r,s[x].rs,a));
}
void dfs(int x)
{
	for(int i=head[x];i!=-1;i=next[i])	dfs(to[i]),rt[x]=merge(rt[x],rt[to[i]]);
	updata(1,m,rt[x],v[x],m,query(1,m,rt[x],v[x]-1)+1);
}
void find(int x)
{
	ans=max(ans,s[x].tag);
	pushdown(x);
	if(s[x].ls)	find(s[x].ls);
	if(s[x].rs)	find(s[x].rs);
}
inline int rd()
{
	int ret=0,f=1;	char gc=getchar();
	while(gc<'0'||gc>'9')	{if(gc=='-')	f=-f;	gc=getchar();}
	while(gc>='0'&&gc<='9')	ret=ret*10+(gc^'0'),gc=getchar();
	return ret*f;
}
int main()
{
	n=rd();
	int i,a;
	memset(head,-1,sizeof(head));
	for(i=1;i<=n;i++)
	{
		val[i]=rd(),a=rd(),p[i]=i;
		if(i!=1)	add(a,i);
	}
	sort(p+1,p+n+1,cmp);
	for(i=1;i<=n;i++)
	{
		if(i==1||val[p[i]]>val[p[i-1]])	m++;
		v[p[i]]=m;
	}
	dfs(1),find(rt[1]);
	printf("%d",ans);
	return 0;
}
posted @ 2017-12-09 13:34  CQzhangyu  阅读(1191)  评论(0编辑  收藏  举报