树上莫队
20231012
树上莫队
由于联考考到,又直接爆0,于是来学习。
树上莫队——把莫队放到树上。
但我是真的不知道把莫队怎么放到树上。。。
于是我们考虑一个东西叫做欧拉序,
就是再 dfs 的时候在进栈和出栈的地方都记录一下。
而在区间查询的时候,我们只对区间出现一次的数统计答案,
用一个数组维护即可,就是每次改变 \(\oplus 1\)。
而对于每一次 \(u,v\) ,
如果在一个子树之内,我们就从 \(st[u]\) 查到 \(st[v]\),
反之,从 \(ed[u]\) 查到 \(st[v]\) 即可。
注意对于不在同一棵子树中的,
需要把 lca 的贡献加上。
SP10707 模板题
#include <bits/stdc++.h>
using namespace std;
const int N=4e4+5,M=1e5+5;
int bl,n,m,a[N],b[N],len,l,r,vis[N],c[N],cnt=0,top[N],st[N],ans[M],ed[N],idx=0,dep[N],sq[M],fa[N],son[N],siz[N],head[N],tot=0;
struct edge{
int v,nxt;
}e[N<<1];
struct node{
int l,r,u,v,bl,id,lca;
bool operator <(const node &rhs) const{
if(bl!=rhs.bl) return bl<rhs.bl;
return r<rhs.r;
}
}q[M];
int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-') f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
void print(int x){
int p[15],tmp=0;
if(x==0) putchar('0');
if(x<0) putchar('-'),x=-x;
while(x){
p[++tmp]=x%10;
x/=10;
}
for(int i=tmp;i>=1;i--) putchar(p[i]+'0');
putchar('\n');
}
void add(int u,int v){
e[++tot]=(edge){v,head[u]};
head[u]=tot;
e[++tot]=(edge){u,head[v]};
head[v]=tot;
}
void dfs1(int u,int pre){
fa[u]=pre,siz[u]=1;son[u]=-1;dep[u]=dep[pre]+1;
st[u]=++idx;sq[idx]=u;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==pre) continue;
dfs1(v,u);
if(son[u]==-1||siz[son[u]]<siz[v]) son[u]=v;
siz[u]+=siz[v];
}
ed[u]=++idx;sq[idx]=u;
}
void dfs2(int u,int pre){
top[u]=pre;
if(son[u]==-1) return ;
dfs2(son[u],pre);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
int lca(int u,int v){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
u=fa[top[u]];
}
if(dep[u]>dep[v]) swap(u,v);
return u;
}
void upd(int i){cnt+=(++c[a[i]]==1);}
void del(int i){cnt-=(--c[a[i]]==0);}
void wrk(int i){
if(vis[i]) del(i);
else upd(i);
vis[i]^=1;
}
int main(){
/*2023.10.12 H_W_Y SP10707 COT2 - Count on a tree II 树上莫队*/
n=read();m=read();
for(int i=1;i<=n;i++) a[i]=read(),b[i]=a[i];
sort(b+1,b+n+1);
len=unique(b+1,b+n+1)-b-1;
for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+len+1,a[i])-b;
for(int i=1,u,v;i<n;i++) u=read(),v=read(),add(u,v);
dfs1(1,0);dfs2(1,1);bl=sqrt((n<<1));
for(int i=1;i<=m;i++){
q[i].u=read();q[i].v=read();q[i].id=i;
int &u=q[i].u,&v=q[i].v;q[i].lca=lca(u,v);
if(st[u]>st[v]) swap(u,v);
if(u==q[i].lca) q[i].l=st[u],q[i].r=st[v],q[i].bl=st[u]/bl,q[i].lca=0;
else q[i].l=ed[u],q[i].r=st[v],q[i].bl=ed[u]/bl;
}
sort(q+1,q+m+1);l=1,r=0;
for(int i=1;i<=m;i++){
while(q[i].r>r) wrk(sq[++r]);
while(q[i].l<l) wrk(sq[--l]);
while(q[i].r<r) wrk(sq[r--]);
while(q[i].l>l) wrk(sq[l++]);
if(q[i].lca) wrk(q[i].lca);
ans[q[i].id]=cnt;
if(q[i].lca) wrk(q[i].lca);
}
for(int i=1;i<=m;i++) print(ans[i]);
return 0;
}