P4689 [Ynoi2016] 这是我自己的发明
对根为 \(1\) 的有点权的树支持如下操作:
- 换根
- 给定 \(x,y\),求 \(\displaystyle \sum_{u\in\operatorname{subtree}(x)}\sum_{v\in\operatorname{subtree}(y)}[a_u=a_v]\)。
\(n\le 10^5\),\(m\le 5\times 10^5\),\(a_i\le 10^9\)。
把子树计数转到 dfs 序上,发现这个就是带换根的 P5268 [SNOI2017] 一个简单的询问。
先看不换根怎么办。记 \(\{c_n\}\) 为 dfs 序上的 \(\{a_n\}\),\(\displaystyle f(a,b,c,d)=\sum_{u\in[a,b]}\sum_{v\in[c,d]}[c_u=c_v]\)。
差分,记 \(\displaystyle g(a,b)=\sum_{x}cnt_x[1,a]\cdot cnt_x[1,b]\),那么 \(f(a,b,c,d)=g(b,d)-g(b,c-1)-g(a-1,d)+g(a-1,c-1)\),拆成四个询问,可以莫队解决。
处理换根:
-
若 \(rt=u\),则 \(u\) 对应区间 \([1,n]\)。
-
若 \(rt\) 在子树 \(u\) 外,则 \(u\) 对应区间 \([st_u,ed_u]\)。
-
否则,令 \(v\) 为链 \(u\rightarrow rt\) 上离 \(u\) 最近的点,\(u\) 对应区间 \([1,st_v-1]\cup[ed_v+1,n]\)。
发现如果 \(u\) 和 \(v\) 均对应两端区间,询问量可以到 \(16m\),不是很理想,考虑减少询问数。
- \(u,v\) 均对应一段区间
上面写过了。
- \(u\) 对应一段区间,\(v\) 对应两端区间
也就是
记 \(h(x)=g(x,n)\),即 \(g(r_1,l_2)-g(l_1-1,l_2)-g(r_1,r_2-1)+g(l_1-1,r_2-1)+h(r_1)-h(l_1-1)\),\(h\) 容易预处理。
发现只用拆 \(4\) 个询问。
- \(u,v\) 均对应两端区间
也能推出来是
也只用拆 \(4\) 个询问。
取 \(B=\dfrac{n}{\sqrt{m'}}\),时间复杂度 \(O(n\sqrt{m})\)。
#include<bits/stdc++.h>
#define ll long long
#define N 100010
#define M 500010
using namespace std;
int read(){
int x=0,w=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
while(isdigit(ch))x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*w;
}
int n,m,B,a[N];
int b[N],len;
struct Q{
int l,r,id,fl;
bool operator<(const Q &x)const{
if(l/B!=x.l/B)return l<x.l;
return (l/B)&1?r>x.r:r<x.r;
}
}q[M<<2];int tq;
void ins(int l,int r,int fl){
if(l>r)swap(l,r);
if(l<1||r>n)return;
q[++tq]={l,r,m,fl};
}
vector<int>e[N];
int st[N],ed[N],tim,c[N];
int fa[N][17],dep[N];
int fath(int x,int k){
for(int i=16;~i;i--)
if((k>>i)&1)x=fa[x][i];
return x;
}
void dfs(int u,int f){
st[u]=++tim,c[tim]=a[u];
fa[u][0]=f,dep[u]=dep[f]+1;
for(int i=1;i<17;i++)fa[u][i]=fa[fa[u][i-1]][i-1];
for(int v:e[u])
if(v!=f)dfs(v,u);
ed[u]=tim;
}
void get(int u,int rt,int &l,int &r,int &f){
if(u==rt)return l=1,r=n,f=false,void();
if(st[rt]<st[u]||ed[rt]>ed[u])
return l=st[u],r=ed[u],f=false,void();
int v=fath(rt,dep[rt]-dep[u]-1);
l=st[v]-1,r=ed[v]+1,f=true;
}
int cntl[N],cntr[N],cntn[N];ll sum,h[N],ans[M];
void llht(int x){sum-=cntr[x],cntl[x]--;}
void rrht(int x){sum+=cntl[x],cntr[x]++;}
void lrht(int x){sum+=cntr[x],cntl[x]++;}
void rlht(int x){sum-=cntl[x],cntr[x]--;}
int main(){
n=read();int T=read();
for(int i=1;i<=n;i++)a[i]=b[i]=read();
sort(b+1,b+1+n),len=unique(b+1,b+1+n)-b-1;
for(int i=1;i<=n;i++)
a[i]=lower_bound(b+1,b+1+len,a[i])-b;
for(int i=1,u,v;i<n;i++){
u=read(),v=read();
e[u].push_back(v),e[v].push_back(u);
}
dfs(1,0);
for(int i=1;i<=n;i++)cntn[c[i]]++;
for(int i=1;i<=n;i++)
h[i]=h[i-1]+cntn[c[i]];
for(int opt,rt=1,u,v,l1,r1,l2,r2,f1,f2;T;T--){
opt=read();
if(opt==1){rt=read();continue;}
u=read(),v=read(),m++;
get(u,rt,l1,r1,f1),get(v,rt,l2,r2,f2);
if(!f2)swap(u,v),swap(l1,l2),swap(r1,r2),swap(f1,f2);
if(!f1&&!f2)
ins(r1,r2,1),ins(l1-1,r2,-1),ins(r1,l2-1,-1),ins(l1-1,l2-1,1);
else if(!f1){
ans[m]=h[r1]-h[l1-1];
ins(r1,l2,1),ins(l1-1,l2,-1),ins(r1,r2-1,-1),ins(l1-1,r2-1,1);
}
else{
ans[m]=h[n]+h[l2]+h[l1]-h[r1-1]-h[r2-1];
ins(l1,l2,1),ins(r1-1,l2,-1),ins(l1,r2-1,-1),ins(r1-1,r2-1,1);
}
}
B=ceil(n/sqrt(tq));
sort(q+1,q+1+tq);
for(int i=1,l=0,r=0;i<=tq;i++){
while(l>q[i].l)llht(c[l--]);
while(r<q[i].r)rrht(c[++r]);
while(l<q[i].l)lrht(c[++l]);
while(r>q[i].r)rlht(c[r--]);
ans[q[i].id]+=q[i].fl*sum;
}
for(int i=1;i<=m;i++)
printf("%lld\n",ans[i]);
return 0;
}