bzoj2588 count on a tree
【写在前面】
这道gou题的一大坑点:由于强制在线,如果你某次输出的答案不对,下一次更新u的时候就很容易导致Re。。。
所以你Re了不急着改大数组,先看看自己是不是wa了。。。
【题目大意】
给你一棵树,求某条链上第k小的点权。。。
【题解】
主席数的基本操作,以rt[i]为根节点的线段树存的是树上从1号节点到i号节点的信息。。。
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; int n,m,q,cnt,tot,val,lastans; int head[100005]; int rt[100005]; int fa[100005]; int dep[100005]; int siz[100005]; int son[100005]; int grand[100005]; int ans[100005]; struct node{ int v; int w; int id; }poi[100005]; struct Tree{ int num; int ls; int rs; }tr[5000005]; struct Edge{ int fr; int to; int nxt; }edge[200005]; void init(){ dep[1]=1; memset(head,-1,sizeof(head)); } void addedge(int f,int t){ cnt++; edge[cnt].fr=f; edge[cnt].to=t; edge[cnt].nxt=head[f]; head[f]=cnt; } int cmp1(node a,node b){ return a.v<b.v; } int cmp2(node a,node b){ return a.id<b.id; } void insert(int &x,int y,int l,int r,int p){ x=++tot; if(l==r){ tr[x].num=tr[y].num+1; return; } int mid=(l+r)>>1; tr[x].ls=tr[y].ls; tr[x].rs=tr[y].rs; if(p<=mid)insert(tr[x].ls,tr[y].ls,l,mid,p); else insert(tr[x].rs,tr[y].rs,mid+1,r,p); tr[x].num=tr[tr[x].ls].num+tr[tr[x].rs].num; } void dfs1(int u){ siz[u]=1; for(int i=head[u];i!=-1;i=edge[i].nxt){ int v=edge[i].to; if(v==fa[u])continue; fa[v]=u; dep[v]=dep[u]+1; dfs1(v); siz[u]+=siz[v]; if(siz[v]>siz[son[u]]){ son[u]=siz[v]; } } } void dfs2(int u){ if(u!=son[fa[u]])grand[u]=u; else grand[u]=grand[fa[u]]; for(int i=head[u];i!=-1;i=edge[i].nxt){ int v=edge[i].to; if(v==fa[u])continue; dfs2(v); } } void dfs3(int u){ insert(rt[u],rt[fa[u]],1,val,poi[u].w); for(int i=head[u];i!=-1;i=edge[i].nxt){ int v=edge[i].to; if(v==fa[u])continue; dfs3(v); } } int lca(int x,int y){ while(grand[x]!=grand[y]){ if(dep[grand[x]]<dep[grand[y]]){ swap(x,y); } x=fa[grand[x]]; } if(dep[x]>dep[y]){ swap(x,y); } return x; } int query(int u,int v,int f,int g,int l,int r,int k){ if(l==r){ return ans[l]; } int mid=(l+r)>>1; int num=tr[tr[u].ls].num+tr[tr[v].ls].num-tr[tr[f].ls].num-tr[tr[g].ls].num; if(k<=num){ return query(tr[u].ls,tr[v].ls,tr[f].ls,tr[g].ls,l,mid,k); } else{ return query(tr[u].rs,tr[v].rs,tr[f].rs,tr[g].rs,mid+1,r,k-num); } } int main(){ init(); scanf("%d%d",&n,&m); for(int i=1;i<=n;i++){ scanf("%d",&poi[i].v); poi[i].id=i; } poi[0].v=-1; sort(poi+1,poi+n+1,cmp1); for(int i=1;i<=n;i++){ if(poi[i].v!=poi[i-1].v){ val++; ans[val]=poi[i].v; } poi[i].w=val; } sort(poi+1,poi+n+1,cmp2); for(int i=1;i<n;i++){ int u,v; scanf("%d%d",&u,&v); addedge(u,v); addedge(v,u); } dfs1(1); dfs2(1); dfs3(1); for(int i=1;i<=m;i++){ int u,v,k,f,g; scanf("%d%d%d",&u,&v,&k); u^=lastans; f=lca(u,v); g=fa[f]; lastans=query(rt[u],rt[v],rt[f],rt[g],1,val,k); printf("%d",lastans); if(i!=m){ printf("\n"); } } return 0; }