CodeForces 797D Broken BST

$dfs$,线段树。

通过观察可以发现,某位置要能被找到,和他到根这条路上的每个节点的权值存在密切的联系,且是父节点的左儿子还是右儿子也有联系。

可以从根开始$dfs$,边走边更新线段树,如果遍历左儿子,那么将$[1,val-1]$全部加$1$,否则将$[val+1,n]$全部加$1$,回溯的时候减$1$,判断某位置能否到达可以比较单点值与深度的关系。

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <string>
#include <queue>
#include <stack>
#include <vector>
#include <algorithm>
using namespace std;

int f[400010];
int s[400010];
int res;

void pushDown(int rt)
{
	if(f[rt]==0) return ;
	s[2*rt] += f[rt];
	s[2*rt+1] += f[rt];
	f[2*rt] += f[rt];
	f[2*rt+1] += f[rt];
	f[rt] = 0;
	return ;
}

void pushUp(int rt)
{
	s[rt] = s[2*rt] + s[2*rt+1];
}

void update(int L,int R,int val,int l,int r,int rt)
{
	if(L<=l&&r<=R)
	{
		s[rt] += val;
		f[rt] += val;
		return ;
	}

	int m = (l+r)/2;
	pushDown(rt);
	if(L<=m) update(L,R,val,l,m,2*rt);
	if(R>m)  update(L,R,val,m+1,r,2*rt+1);
	pushUp(rt);
}

void query(int pos,int l,int r,int rt)
{
	if(l==r)
	{
		res = s[rt];
		return;
	}

	int m = (l+r)/2;
	pushDown(rt);
	if(pos<=m) query(pos,l,m,2*rt);
	else query(pos,m+1,r,2*rt+1);
	pushUp(rt);

}

int n;
struct X
{
	int val;
	int left,right;
}node[100010];
int root;
int b[100010],sz;
int ans;

int get(int x)
{
	int L = 0,R = sz-1;

	while(L<=R)
	{
		int mid = (L+R)/2;
		if(b[mid]>x) R = mid-1;
		else if(b[mid] == x) return mid+1;
		else L = mid+1;
	}
}

int u[100010];

void dfs(int x,int y)
{
	query(node[x].val,1,n,1);
	if(res != y) {}
	else u[node[x].val]=1;

	if(node[x].left!=-1) 
	{
		if(node[x].val>1) update(1,node[x].val-1,1,1,n,1);
		dfs(node[x].left,y+1);
		if(node[x].val>1) update(1,node[x].val-1,-1,1,n,1);
	}

	if(node[x].right!=-1) 
	{
		if(node[x].val<n) update(node[x].val+1,n,1,1,n,1);
		dfs(node[x].right,y+1);
		if(node[x].val<n) update(node[x].val+1,n,-1,1,n,1);
	}
	
}

int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++)
		scanf("%d%d%d",&node[i].val,&node[i].left,&node[i].right);

	for(int i=1;i<=n;i++) b[sz++] = node[i].val;
	sort(b,b+sz);
	for(int i=1;i<=n;i++) 
	{
		node[i].val = get(node[i].val);
		u[node[i].val]=1;
	}

	int sum=0;
	for(int i=1;i<=n;i++) sum=sum+1;

	for(int i=1;i<=n;i++)
	{
		if(node[i].left!=-1) f[node[i].left] = 1;
		if(node[i].right!=-1) f[node[i].right] = 1;
	}

	for(int i=1;i<=n;i++)
	{
		if(f[i]) continue;
		root = i; break;
	}

	memset(f,0,sizeof f);
	memset(u,0,sizeof u);
	dfs(root,0);	
	for(int i=1;i<=n;i++) sum=sum-u[node[i].val];
	
	printf("%d\n",sum);
	
	return 0;
}

 

posted @ 2017-05-09 09:50  Fighting_Heart  阅读(240)  评论(0编辑  收藏  举报