CodeForces 797D Broken BST
$dfs$,线段树。
通过观察可以发现,某位置要能被找到,和他到根这条路上的每个节点的权值存在密切的联系,且是父节点的左儿子还是右儿子也有联系。
可以从根开始$dfs$,边走边更新线段树,如果遍历左儿子,那么将$[1,val-1]$全部加$1$,否则将$[val+1,n]$全部加$1$,回溯的时候减$1$,判断某位置能否到达可以比较单点值与深度的关系。
#include <iostream> #include <cstdio> #include <cmath> #include <cstring> #include <string> #include <queue> #include <stack> #include <vector> #include <algorithm> using namespace std; int f[400010]; int s[400010]; int res; void pushDown(int rt) { if(f[rt]==0) return ; s[2*rt] += f[rt]; s[2*rt+1] += f[rt]; f[2*rt] += f[rt]; f[2*rt+1] += f[rt]; f[rt] = 0; return ; } void pushUp(int rt) { s[rt] = s[2*rt] + s[2*rt+1]; } void update(int L,int R,int val,int l,int r,int rt) { if(L<=l&&r<=R) { s[rt] += val; f[rt] += val; return ; } int m = (l+r)/2; pushDown(rt); if(L<=m) update(L,R,val,l,m,2*rt); if(R>m) update(L,R,val,m+1,r,2*rt+1); pushUp(rt); } void query(int pos,int l,int r,int rt) { if(l==r) { res = s[rt]; return; } int m = (l+r)/2; pushDown(rt); if(pos<=m) query(pos,l,m,2*rt); else query(pos,m+1,r,2*rt+1); pushUp(rt); } int n; struct X { int val; int left,right; }node[100010]; int root; int b[100010],sz; int ans; int get(int x) { int L = 0,R = sz-1; while(L<=R) { int mid = (L+R)/2; if(b[mid]>x) R = mid-1; else if(b[mid] == x) return mid+1; else L = mid+1; } } int u[100010]; void dfs(int x,int y) { query(node[x].val,1,n,1); if(res != y) {} else u[node[x].val]=1; if(node[x].left!=-1) { if(node[x].val>1) update(1,node[x].val-1,1,1,n,1); dfs(node[x].left,y+1); if(node[x].val>1) update(1,node[x].val-1,-1,1,n,1); } if(node[x].right!=-1) { if(node[x].val<n) update(node[x].val+1,n,1,1,n,1); dfs(node[x].right,y+1); if(node[x].val<n) update(node[x].val+1,n,-1,1,n,1); } } int main() { scanf("%d",&n); for(int i=1;i<=n;i++) scanf("%d%d%d",&node[i].val,&node[i].left,&node[i].right); for(int i=1;i<=n;i++) b[sz++] = node[i].val; sort(b,b+sz); for(int i=1;i<=n;i++) { node[i].val = get(node[i].val); u[node[i].val]=1; } int sum=0; for(int i=1;i<=n;i++) sum=sum+1; for(int i=1;i<=n;i++) { if(node[i].left!=-1) f[node[i].left] = 1; if(node[i].right!=-1) f[node[i].right] = 1; } for(int i=1;i<=n;i++) { if(f[i]) continue; root = i; break; } memset(f,0,sizeof f); memset(u,0,sizeof u); dfs(root,0); for(int i=1;i<=n;i++) sum=sum-u[node[i].val]; printf("%d\n",sum); return 0; }