【题解】[HDU 5709] Claris Loves Painting【线段树合并】
题意
一棵有根树,点有颜色。多次询问:\(x\) 的子树中,与 \(x\) 距离不超过 \(d\) 的所有点中,有多少种不同的颜色。多组数据。
\(n,m\leq 10^5\),\(\sum n,\sum m\leq 5\times 10^5\)
题解
对于每个结点 \(u\) 维护两个线段树:
- 第一个以深度为下标,记录:在 \(u\) 的子树中,只考虑每种颜色首次出现的位置,深度为 \(i\) 的颜色的数量。该线段树还需维护区间和;
- 第二个以颜色为下标,记录:在 \(u\) 的子树中,颜色 \(i\) 出现的最浅位置。
自底向上做线段树合并,先合并第一个,再合并第二个,并在第二个出现重复的颜色时修改第一个,只保留较浅的点。合并时还要可持久化。
#include<bits/stdc++.h>
using namespace std;
int getint(){
int a=0;
char c=getchar();
while(c<'0'||c>'9')c=getchar();
while(c>='0'&&c<='9'){
a=a*10+c-'0';
c=getchar();
}
return a;
}
const int N=1e5+10;
int n,m;
int ch1[N<<6][2],ch2[N<<6][2],sum[N<<6],mn[N<<6],cnt1,cnt2;
void pushup1(int x){ sum[x]=sum[ch1[x][0]]+sum[ch1[x][1]]; }
void mdf1(int p,int v,int x,int nl,int nr){
if(nl==nr){
sum[x]+=v;
return;
}
int mid=nl+nr>>1;
if(p<=mid){
memcpy(ch1[cnt1+1],ch1[ch1[x][0]],sizeof(int)*2);
sum[cnt1+1]=sum[ch1[x][0]];
ch1[x][0]=++cnt1;
mdf1(p,v,ch1[x][0],nl,mid);
}else{
memcpy(ch1[cnt1+1],ch1[ch1[x][1]],sizeof(int)*2);
sum[cnt1+1]=sum[ch1[x][1]];
ch1[x][1]=++cnt1;
mdf1(p,v,ch1[x][1],mid+1,nr);
}
pushup1(x);
}
void mdf2(int p,int v,int x,int nl,int nr){
if(nl==nr){
mn[x]=v;
return;
}
int mid=nl+nr>>1;
if(p<=mid){
memcpy(ch2[cnt2+1],ch2[ch2[x][0]],sizeof(int)*2);
ch2[x][0]=++cnt2;
mdf2(p,v,ch2[x][0],nl,mid);
}else{
memcpy(ch2[cnt2+1],ch2[ch2[x][1]],sizeof(int)*2);
ch2[x][1]=++cnt2;
mdf2(p,v,ch2[x][1],mid+1,nr);
}
}
int mer1(int l,int r,int x,int y){
if(!(x&&y))return x|y;
if(l==r){
++cnt1;
sum[cnt1]=sum[x]+sum[y];
return cnt1;
}
int t=++cnt1,mid=l+r>>1;
ch1[t][0]=mer1(l,mid,ch1[x][0],ch1[y][0]);
ch1[t][1]=mer1(mid+1,r,ch1[x][1],ch1[y][1]);
pushup1(t);
return t;
}
int mer2(int l,int r,int x,int y,int rt1){
if(!(x&&y))return x|y;
if(l==r){
++cnt2;
int p=mn[x],q=mn[y];
if(p>q)swap(p,q);
mn[cnt2]=p;
mdf1(q,-1,rt1,1,n);
// mdf1(p,1,rt1,1,n);
return cnt2;
}
int t=++cnt2,mid=l+r>>1;
ch2[t][0]=mer2(l,mid,ch2[x][0],ch2[y][0],rt1);
ch2[t][1]=mer2(mid+1,r,ch2[x][1],ch2[y][1],rt1);
return t;
}
int query(int l,int r,int x,int nl,int nr){
if(!x)return 0;
if(nr<l||nl>r)return 0;
if(l<=nl&&nr<=r)return sum[x];
int mid=nl+nr>>1;
return query(l,r,ch1[x][0],nl,mid)+query(l,r,ch1[x][1],mid+1,nr);
}
int rt1[N],rt2[N];
struct bian{
int e,n;
};
bian b[N];
int s[N],tot=0;
void add(int x,int y){
tot++;
b[tot].e=y;
b[tot].n=s[x];
s[x]=tot;
}
int dep[N];
int col[N],fa[N];
int main(){
int T=getint();
while(T--){
n=getint(),m=getint();
for(int i=1;i<=n;i++)col[i]=getint();
for(int i=2;i<=n;i++)fa[i]=getint(),add(fa[i],i);
dep[1]=1;for(int i=2;i<=n;i++)dep[i]=dep[fa[i]]+1;
for(int x=n;x;--x){
rt1[x]=++cnt1;
mdf1(dep[x],1,rt1[x],1,n);
rt2[x]=++cnt2;
mdf2(col[x],dep[x],rt2[x],1,n);
for(int i=s[x];i;i=b[i].n){
int v=b[i].e;
rt1[x]=mer1(1,n,rt1[x],rt1[v]);
rt2[x]=mer2(1,n,rt2[x],rt2[v],rt1[x]);
}
}
int lastans=0;
while(m--){
int x=getint(),y=getint();
x^=lastans;
y^=lastans;
lastans=query(dep[x],min(dep[x]+y,n),rt1[x],1,n);
printf("%d\n",lastans);
}
memset(ch1,0,sizeof(ch1[0])*(cnt1+2));
memset(ch2,0,sizeof(ch2[0])*(cnt2+2));
memset(sum,0,sizeof(sum[0])*(cnt1+2));
memset(mn,0,sizeof(mn[0])*(cnt2+2));
cnt1=cnt2=0;
memset(s,0,sizeof(s[0])*(n+1));
memset(b,0,sizeof(b[0])*(tot+1));
tot=0;
}
}