【题解】[USACO17JAN]Promotion Counting P
\(\text{Solution:}\)
题目求的就是一棵子树中大于根节点权值的节点数。
这东西一看就很权值线段树。
然后发现这东西又很线段树合并。
考虑对每一个点维护一棵权值线段树。这样,我们将子树的信息合并到根的权值树上,就可以做到 \(n\log n\) 合并信息了。
然后对每个节点进行查询即可。权值树上维护一下区间和即可。由于空间问题要动态开点。
线段树合并的空间复杂度是 \(O(n\log n)\) 的,又有一个约为 \(2\) 的常数(从某处看到的)(这是动态开点线段树合并的复杂度)
所以开到 \(3.2\cdot 10^6\) 就可以了。
懒得思考怎么直接求,实际上可以用差分的思想,更新完一棵子树后,用更新前的信息减去更新后的信息即可。这样用树状数组也可以解决问题。
就当练习一下线段树合并了。
#include<bits/stdc++.h>
using namespace std;
const int MAXN=3.2e6+10;
const int N=2e5+10;
int head[MAXN],tot,cnt,n,m;
int ls[MAXN],rs[MAXN],ans[MAXN];
int dep[MAXN],pa[MAXN],siz[MAXN];
int b[MAXN],bcnt,blen;
int v[MAXN],val[MAXN];
struct E{int nxt,to;}e[MAXN];
inline void add(int x,int y){e[++tot]=(E){head[x],y};head[x]=tot;}
int rt[MAXN],sum[MAXN];
inline int getpos(int v){return lower_bound(b+1,b+blen+1,v)-b;}
void dfs(int x,int fa){
dep[x]=dep[fa]+1;
siz[x]=1;
for(int i=head[x];i;i=e[i].nxt){
int j=e[i].to;
if(j==fa)continue;
dfs(j,x);siz[x]+=siz[j];
}
}
inline void pushup(int x){sum[x]=sum[ls[x]]+sum[rs[x]];}
int change(int x,int l,int r,int pos,int v){
if(!x)x=++cnt;
if(l==r){
sum[x]+=v;
return x;
}
int mid=(l+r)>>1;
if(pos<=mid)ls[x]=change(ls[x],l,mid,pos,v);
else rs[x]=change(rs[x],mid+1,r,pos,v);
pushup(x);return x;
}
int merge(int x,int y,int l,int r){
if(!x||!y)return x+y;
if(l==r){
sum[x]+=sum[y];
return x;
}
int mid=(l+r)>>1;
ls[x]=merge(ls[x],ls[y],l,mid);
rs[x]=merge(rs[x],rs[y],mid+1,r);
pushup(x);return x;
}
int query(int x,int l,int r,int ql,int qr){
if(l>=ql&&r<=qr)return sum[x];
int mid=(l+r)>>1,ans=0;
if(ql<=mid)ans+=query(ls[x],l,mid,ql,qr);
if(mid<qr)ans+=query(rs[x],mid+1,r,ql,qr);
return ans;
}
void DFS(int x){
for(int i=head[x];i;i=e[i].nxt){
int j=e[i].to;
if(j==pa[x])continue;
DFS(j);
rt[x]=merge(rt[x],rt[j],1,N);
}
ans[x]=query(rt[x],1,N,val[x]+1,N);
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;++i)scanf("%d",&v[i]);
for(int i=1;i<=n;++i)b[++bcnt]=v[i];
sort(b+1,b+bcnt+1);
blen=unique(b+1,b+bcnt+1)-b-1;
for(int i=1;i<=n;++i){
int pos=getpos(v[i]);
val[i]=pos;
}
for(int i=1;i<=n;++i)rt[i]=change(rt[i],1,N,val[i],1);
for(int i=2;i<=n;++i){
int x;
scanf("%d",&x);
add(x,i);pa[i]=x;
}
dfs(1,0);
DFS(1);
for(int i=1;i<=n;++i)printf("%d\n",ans[i]);
return 0;
}