SPOJ 10707 COT2 - Count on a tree II
思路
树上莫队的题目
每次更新(u1,u2)和(v1,v2)(不包括lca)的路径,最后单独统计LCA即可
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <stack>
#include <cmath>
using namespace std;
int v[100100*2],fir[100100],nxt[100100*2],cnt=0,w_p[100100],b[100100],n,m,tx,sum,jump[100100][20],dep[100100],in_ans[100100],sz,belong[100100],block_cnt,ans[100100],color[100100],U,V;
stack<int> S;
void addedge(int ui,int vi){
++cnt;
v[cnt]=vi;
nxt[cnt]=fir[ui];
fir[ui]=cnt;
}
void dfs(int u,int f){
jump[u][0]=f;
for(int i=1;i<20;i++)
jump[u][i]=jump[jump[u][i-1]][i-1];
dep[u]=dep[f]+1;
int t=S.size();
for(int i=fir[u];i;i=nxt[i]){
if(v[i]==f)
continue;
dfs(v[i],u);
if(S.size()-t>=sz){
++block_cnt;
while(S.size()>t){
belong[S.top()]=block_cnt;
S.pop();
}
}
}
S.push(u);
}
void init(void){
sz=sqrt(n);
dfs(1,0);
}
int lca(int x,int y){
if(dep[x]<dep[y])
swap(x,y);
for(int i=19;i>=0;i--)
if(dep[jump[x][i]]>=dep[y])
x=jump[x][i];
if(x==y)
return x;
for(int i=19;i>=0;i--)
if(jump[x][i]!=jump[y][i])
x=jump[x][i],y=jump[y][i];
return jump[x][0];
}
void modi_point(int x){
if(in_ans[x]){//erase
color[w_p[x]]--;
if(!color[w_p[x]])
sum--;
}
else{
if(!color[w_p[x]])
sum++;
color[w_p[x]]++;
}
in_ans[x]^=1;
}
void move_path(int x,int y){//(x,y) except LCA(x,y)
if(dep[x]<dep[y])
swap(x,y);
while(dep[x]>dep[y]){
modi_point(x);
x=jump[x][0];
}
while(x!=y){
modi_point(x);
modi_point(y);
x=jump[x][0];
y=jump[y][0];
}
}
struct Query{
int u,v,id;
bool operator < (const Query &b) const{
return (belong[u]==belong[b.u])?belong[v]<belong[b.v]:belong[u]<belong[b.u];
}
}Q[100100];
int main(){
scanf("%d %d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&w_p[i]),b[i]=w_p[i];
sort(b+1,b+n+1);
tx=unique(b+1,b+n+1)-(b+1);
for(int i=1;i<=n;i++)
w_p[i]=lower_bound(b+1,b+n+1,w_p[i])-b;
for(int i=1;i<n;i++){
int a,b;
scanf("%d %d",&a,&b);
addedge(a,b);
addedge(b,a);
}
init();
for(int i=1;i<=m;i++){
scanf("%d %d",&Q[i].u,&Q[i].v);
Q[i].id=i;
}
sort(Q+1,Q+m+1);
U=V=1;
for(int i=1;i<=m;i++){
move_path(U,Q[i].u);
move_path(V,Q[i].v);
U=Q[i].u;
V=Q[i].v;
int Lca=lca(Q[i].u,Q[i].v);
modi_point(Lca);
ans[Q[i].id]=sum;
modi_point(Lca);
}
for(int i=1;i<=m;i++)
printf("%d\n",ans[i]);
return 0;
}