【洛谷P2633】Count on a tree

题目

题目链接:https://www.luogu.com.cn/problem/P2633
给定一棵 \(n\) 个节点的树,每个点有一个权值。有 \(m\) 个询问,每次给你 \(u,v,k\),你需要回答 \(u \text{ xor last}\)\(v\) 这两个节点间第 \(k\) 小的点权。

其中 \(\text{last}\) 是上一个询问的答案,定义其初始为 \(0\),即第一个询问的 \(u\) 是明文。

思路

感觉好久没写主席树了,正好发现聪爷写了一道主席树。发现挺好想的,就当练习模板了。
对于一组询问 \((u,v)\),设 \(p=\operatorname{lca}(u,v),q=fa[p]\),设 \(cnt[x]\) 表示 \(x\) 到根之间不超过 \(k\) 的数的个数,那么 \(ans=cnt[u]+cnt[v]-cnt[p]-cnt[q]\)
主席树乱搞即可。
时间复杂度 \(O(n\log n)\)

代码

#include <bits/stdc++.h>
using namespace std;

const int N=100010,LG=20;
int n,m,tot,last,head[N],rt[N],a[N],b[N],f[N][LG+1],dep[N];

struct edge
{
	int next,to;
}e[N*2];

void add(int from,int to)
{
	e[++tot].to=to;
	e[tot].next=head[from];
	head[from]=tot;
}

struct SegTree
{
	int tot,lc[N*LG*4],rc[N*LG*4],sum[N*LG*4];
	
	int ins(int now,int l,int r,int v)
	{
		int x=++tot;
		lc[x]=lc[now]; rc[x]=rc[now]; sum[x]=sum[now]+1;
		if (l==r) return x;
		int mid=(l+r)>>1;
		if (v<=mid) lc[x]=ins(lc[now],l,mid,v);
			else rc[x]=ins(rc[now],mid+1,r,v);
		return x;
	}
	
	int query(int x,int y,int p,int q,int l,int r,int k)
	{
		if (l==r) return l;
		int mid=(l+r)>>1,cnt=sum[lc[x]]+sum[lc[y]]-sum[lc[p]]-sum[lc[q]];
		if (cnt>=k) return query(lc[x],lc[y],lc[p],lc[q],l,mid,k);
			else return query(rc[x],rc[y],rc[p],rc[q],mid+1,r,k-cnt);
	}
}seg;

void dfs(int x,int fa)
{
	dep[x]=dep[fa]+1; f[x][0]=fa;
	for (int i=1;i<=LG;i++)
		f[x][i]=f[f[x][i-1]][i-1];
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa)
		{
			rt[v]=seg.ins(rt[x],1,n,a[v]);
			dfs(v,x);
		}
	}
}

int lca(int x,int y)
{
	if (dep[x]<dep[y]) swap(x,y);
	for (int i=LG;i>=0;i--)
		if (dep[f[x][i]]>=dep[y]) x=f[x][i];
	if (x==y) return x;
	for (int i=LG;i>=0;i--)
		if (f[x][i]!=f[y][i])
			x=f[x][i],y=f[y][i];
	return f[x][0];
}

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%d",&n,&m);
	for (int i=1;i<=n;i++)
	{
		scanf("%d",&a[i]);
		b[i]=a[i];
	}
	sort(b+1,b+1+n);
	int WYCAKIOI=unique(b+1,b+1+n)-b-1;
	for (int i=1;i<=n;i++)
		a[i]=lower_bound(b+1,b+1+WYCAKIOI,a[i])-b;
	for (int i=1,x,y;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	rt[1]=seg.ins(0,1,n,a[1]);
	dfs(1,0);
	for (int i=1,u,v,k;i<=m;i++)
	{
		scanf("%d%d%d",&u,&v,&k);
		u^=last;
		int p=lca(u,v);
		printf("%d\n",last=b[seg.query(rt[u],rt[v],rt[p],rt[f[p][0]],1,n,k)]);
	}
	return 0;
}
posted @ 2020-10-03 01:11  stoorz  阅读(141)  评论(0编辑  收藏  举报