Count on a tree

link

本来想打一个树上主席树放松一下大脑,结果血压上来了。

其实说白了它就是一个模板,只是有一件事是需要格外留意的:

树上差分点权应该是 \(v(s1)+v(s2)-v(lca)-v(fa(lca))\) ,而树上边权差分应该是(下放到点权之后) \(v(s1)+v(s2)-v(lca)\times2\) 。二者是不一样的,一定要记清楚了!!!

其它没什么了。吐槽一下树上主席树的work函数真不优雅,带4个根进去丑的要命。

#include<cstdio>
#include<algorithm>
#define zczc
using namespace std;
const int N=100010;
inline void read(int &wh){
	wh=0;int f=1;char w=getchar();
	while(w<'0'||w>'9'){if(w=='-')f=-1;w=getchar();}
	while(w>='0'&&w<='9'){wh=wh*10+w-'0';w=getchar();}
	wh*=f;return;
}
inline void swap(int &s1,int &s2){
	int s3=s1;s1=s2;s2=s3;return;
}

int m,n,a[N],b[N];

struct edge{
	int t,next;
}e[N<<1];
int head[N],esum;
inline void add(int fr,int to){
	e[++esum]=(edge){to,head[fr]};head[fr]=esum;
}

#define mid (t[wh].l+t[wh].r>>1)
struct node{
	int lc,rc,l,r,data;
}t[N<<5];
int cnt;
inline int build(int l,int r){
	int wh=++cnt;t[wh].l=l,t[wh].r=r;
	if(l^r)t[wh].lc=build(l,mid),t[wh].rc=build(mid+1,r);
	return wh;
}
inline int insert(int x,int pl){
	int wh=++cnt;t[wh]=t[x];t[wh].data++;
	if(t[wh].l==t[wh].r)return wh;
	if(pl<=mid)t[wh].lc=insert(t[x].lc,pl);
	else t[wh].rc=insert(t[x].rc,pl);return wh;
}
inline int work(int r0,int r1,int r2,int r3,int k){
	if(t[r0].l==t[r0].r)return t[r0].l;
	int l0=t[r0].lc,l1=t[r1].lc,l2=t[r2].lc,l3=t[r3].lc;
	int now=t[l2].data+t[l3].data-t[l1].data-t[l0].data;
	if(now>=k)return work(l0,l1,l2,l3,k);
	else return work(t[r0].rc,t[r1].rc,t[r2].rc,t[r3].rc,k-now);
}
#undef mid

int d[N],root[N],nxt[N][27];
void dfs(int wh,int fa){
	d[wh]=d[fa]+1,nxt[wh][0]=fa;
	for(int i=1;i<=25;i++)nxt[wh][i]=nxt[nxt[wh][i-1]][i-1];
	root[wh]=insert(root[fa],a[wh]);
	for(int i=head[wh],th;i;i=e[i].next){
		if((th=e[i].t)==fa)continue;dfs(th,wh);
	}
}
int lca(int s1,int s2){
	if(d[s1]<d[s2])swap(s1,s2);
	for(int i=25;i>=0;i--)
		if(d[nxt[s1][i]]>=d[s2])s1=nxt[s1][i];
	if(s1==s2)return s1;
	for(int i=25;i>=0;i--)
		if(nxt[s1][i]^nxt[s2][i])s1=nxt[s1][i],s2=nxt[s2][i];
	return nxt[s1][0];
}

signed main(){
	
	#ifdef zczc
	freopen("in.txt","r",stdin);
	#endif
	
	read(m);read(n);int s1,s2,s3,lan=0;
	for(int i=1;i<=m;i++){read(a[i]);b[i]=a[i];}
	for(int i=1;i<m;i++){read(s1);read(s2);add(s1,s2);add(s2,s1);}
	sort(b+1,b+m+1);int num=unique(b+1,b+m+1)-b;
	for(int i=1;i<=m;i++)a[i]=lower_bound(b+1,b+num+1,a[i])-b;
	root[0]=build(1,num);d[0]=1;dfs(1,0);
	for(int i=1;i<=n;i++){
		read(s1);read(s2);read(s3);s1^=lan;int l=lca(s1,s2);
		int now=work(root[nxt[l][0]],root[l],root[s1],root[s2],s3);
		printf("%d\n",lan=b[now]);
	}
	
	return 0;
}
posted @ 2022-04-30 14:18  Feyn618  阅读(15)  评论(0编辑  收藏  举报