P4689 [Ynoi2016] 这是我自己的发明

对根为 \(1\) 的有点权的树支持如下操作:

  • 换根
  • 给定 \(x,y\),求 \(\displaystyle \sum_{u\in\operatorname{subtree}(x)}\sum_{v\in\operatorname{subtree}(y)}[a_u=a_v]\)

\(n\le 10^5\)\(m\le 5\times 10^5\)\(a_i\le 10^9\)


把子树计数转到 dfs 序上,发现这个就是带换根的 P5268 [SNOI2017] 一个简单的询问

先看不换根怎么办。记 \(\{c_n\}\) 为 dfs 序上的 \(\{a_n\}\)\(\displaystyle f(a,b,c,d)=\sum_{u\in[a,b]}\sum_{v\in[c,d]}[c_u=c_v]\)

差分,记 \(\displaystyle g(a,b)=\sum_{x}cnt_x[1,a]\cdot cnt_x[1,b]\),那么 \(f(a,b,c,d)=g(b,d)-g(b,c-1)-g(a-1,d)+g(a-1,c-1)\),拆成四个询问,可以莫队解决。

处理换根:

  • \(rt=u\),则 \(u\) 对应区间 \([1,n]\)

  • \(rt\) 在子树 \(u\) 外,则 \(u\) 对应区间 \([st_u,ed_u]\)

  • 否则,令 \(v\) 为链 \(u\rightarrow rt\) 上离 \(u\) 最近的点,\(u\) 对应区间 \([1,st_v-1]\cup[ed_v+1,n]\)

发现如果 \(u\)\(v\) 均对应两端区间,询问量可以到 \(16m\),不是很理想,考虑减少询问数。

  • \(u,v\) 均对应一段区间

上面写过了。

  • \(u\) 对应一段区间,\(v\) 对应两端区间

也就是

\[f(l_1,r_1,1,l_2)+f(l_1,r_1,r_2,n) \]

\[=g(r_1,l_2)-g(l_1-1,l_2)-g(r_1,r_2-1)+g(l_1-1,r_2-1)+g(r_1,n)-g(l_1-1,n) \]

\(h(x)=g(x,n)\),即 \(g(r_1,l_2)-g(l_1-1,l_2)-g(r_1,r_2-1)+g(l_1-1,r_2-1)+h(r_1)-h(l_1-1)\)\(h\) 容易预处理。

发现只用拆 \(4\) 个询问。

  • \(u,v\) 均对应两端区间

\[f(1,l_1,1,l_2)+f(r_1,n,1,l_2)+f(1,l_1,r_2,n)+f(r_1,n,r_2,n) \]

也能推出来是

\[g(l_1,l_2)-g(r_1-1,l_2)-g(l_1,r_2-1)+g(r_1-1,r_2-1)+h(n)+h(l_1)+h(l_2)-h(r_1-1)-h(r_2-1) \]

也只用拆 \(4\) 个询问。

\(B=\dfrac{n}{\sqrt{m'}}\),时间复杂度 \(O(n\sqrt{m})\)

#include<bits/stdc++.h>
#define ll long long
#define N 100010
#define M 500010
using namespace std;
int read(){
	int x=0,w=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
	while(isdigit(ch))x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
	return x*w;
}
int n,m,B,a[N];
int b[N],len;
struct Q{
	int l,r,id,fl;
	bool operator<(const Q &x)const{
		if(l/B!=x.l/B)return l<x.l;
		return (l/B)&1?r>x.r:r<x.r;
	}
}q[M<<2];int tq;
void ins(int l,int r,int fl){
	if(l>r)swap(l,r);
	if(l<1||r>n)return;
	q[++tq]={l,r,m,fl};
}
vector<int>e[N];
int st[N],ed[N],tim,c[N];
int fa[N][17],dep[N];
int fath(int x,int k){
	for(int i=16;~i;i--)
		if((k>>i)&1)x=fa[x][i];
	return x;
}
void dfs(int u,int f){
	st[u]=++tim,c[tim]=a[u];
	fa[u][0]=f,dep[u]=dep[f]+1;
	for(int i=1;i<17;i++)fa[u][i]=fa[fa[u][i-1]][i-1];
	for(int v:e[u])
		if(v!=f)dfs(v,u);
	ed[u]=tim;
}
void get(int u,int rt,int &l,int &r,int &f){
	if(u==rt)return l=1,r=n,f=false,void();
	if(st[rt]<st[u]||ed[rt]>ed[u])
		return l=st[u],r=ed[u],f=false,void();
	int v=fath(rt,dep[rt]-dep[u]-1);
	l=st[v]-1,r=ed[v]+1,f=true;
}
int cntl[N],cntr[N],cntn[N];ll sum,h[N],ans[M];
void llht(int x){sum-=cntr[x],cntl[x]--;}
void rrht(int x){sum+=cntl[x],cntr[x]++;}
void lrht(int x){sum+=cntr[x],cntl[x]++;}
void rlht(int x){sum-=cntl[x],cntr[x]--;}
int main(){
	n=read();int T=read();
	for(int i=1;i<=n;i++)a[i]=b[i]=read();
	sort(b+1,b+1+n),len=unique(b+1,b+1+n)-b-1;
	for(int i=1;i<=n;i++)
		a[i]=lower_bound(b+1,b+1+len,a[i])-b;
	for(int i=1,u,v;i<n;i++){
		u=read(),v=read();
		e[u].push_back(v),e[v].push_back(u);
	}
	dfs(1,0);
	for(int i=1;i<=n;i++)cntn[c[i]]++;
	for(int i=1;i<=n;i++)
		h[i]=h[i-1]+cntn[c[i]];
	for(int opt,rt=1,u,v,l1,r1,l2,r2,f1,f2;T;T--){
		opt=read();
		if(opt==1){rt=read();continue;}
		u=read(),v=read(),m++;
		get(u,rt,l1,r1,f1),get(v,rt,l2,r2,f2);
		if(!f2)swap(u,v),swap(l1,l2),swap(r1,r2),swap(f1,f2);
		if(!f1&&!f2)
			ins(r1,r2,1),ins(l1-1,r2,-1),ins(r1,l2-1,-1),ins(l1-1,l2-1,1);
		else if(!f1){
			ans[m]=h[r1]-h[l1-1];
			ins(r1,l2,1),ins(l1-1,l2,-1),ins(r1,r2-1,-1),ins(l1-1,r2-1,1);
		}
		else{
			ans[m]=h[n]+h[l2]+h[l1]-h[r1-1]-h[r2-1];
			ins(l1,l2,1),ins(r1-1,l2,-1),ins(l1,r2-1,-1),ins(r1-1,r2-1,1);
		}
	}
	B=ceil(n/sqrt(tq));
	sort(q+1,q+1+tq);
	for(int i=1,l=0,r=0;i<=tq;i++){
		while(l>q[i].l)llht(c[l--]);
		while(r<q[i].r)rrht(c[++r]);
		while(l<q[i].l)lrht(c[++l]);
		while(r>q[i].r)rlht(c[r--]);
		ans[q[i].id]+=q[i].fl*sum;
	}
	for(int i=1;i<=m;i++)
		printf("%lld\n",ans[i]);

	return 0;
}
posted @ 2024-02-25 20:20  SError  阅读(5)  评论(0编辑  收藏  举报