[BZOJ4771] 七彩树
题意
给定一棵n个点,每个点带颜色的有根树。点的编号和颜色编号都在1到n,根的编号为1。m次询问,求x子树中与x距离边数不超过k的点中,颜色的种类数目。每个测试点有多组数据。
分析
不妨设1的父亲为0,0包含了所有颜色。不考虑深度限制,对于单独一种颜色c,易知颜色c对于任意两个色c的点之间的路径上的所有询问有且仅有一个贡献,利用树链的并(DFS序+树上差分)即可解决。
考虑深度约束,可以发现按深度对差分序列进行可持久化,使用线段树就可以了。
#include <bits/stdc++.h>
using namespace std;
const int N=500010;
struct Node {
int ls,rs,val;
} t[N*20];
int tot,root[N];
void insert(int&x,int y,int l,int r,int p,int w) {
t[x=++tot]=t[y];
t[x].val+=w; if(l==r) return;
int mid=(l+r)>>1;
if(p<=mid) insert(t[x].ls,t[y].ls,l,mid,p,w);
else insert(t[x].rs,t[y].rs,mid+1,r,p,w);
}
int query(int x,int l,int r,int L,int R) {
if(!x) return 0;
if(L<=l&&r<=R) return t[x].val;
int mid=(l+r)>>1, val=0;
if(L<=mid) val+=query(t[x].ls,l,mid,L,R);
if(mid<R) val+=query(t[x].rs,mid+1,r,L,R);
return val;
}
int n,m,ndp,cnt;
int val[N],siz[N],dfn[N],dep[N],fa[N][20];
vector<int> wuer[N];
queue<int> que;
struct Mogic {
int x;
Mogic(int x=0):x(x){}
bool operator<(const Mogic&d) const {return dfn[x]<dfn[d.x];}
};
set<Mogic> d[N];
int lca(int x,int y) {
if(dep[x]<dep[y]) swap(x,y);
int dif=dep[x]-dep[y];
for(int i=19; ~i; --i) if((dif>>i)&1) x=fa[x][i];
if(x==y) return x;
for(int i=19; ~i; --i) if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
void dfs(int x,int d) {
dfn[x]=++cnt;
dep[x]=dep[fa[x][0]=d]+(siz[x]=1);
for(int i=1; (1<<i)<=dep[x]; ++i) fa[x][i]=fa[fa[x][i-1]][i-1];
for(int i=wuer[x].size()-1; ~i; --i) dfs(wuer[x][i],x),siz[x]+=siz[wuer[x][i]];
}
void bfs() {
que.push(1);
while(que.size()) {
int x=que.front(); que.pop();
if(ndp!=dep[x]) {
root[ndp+1]=root[ndp];
ndp++;
}
insert(root[ndp],root[ndp],1,n,dfn[x],1);
set<Mogic>::iterator k=d[val[x]].insert(Mogic(x)).first;
set<Mogic>::iterator k1=k,k2=k; k1--; k2++;
if(k!=d[val[x]].begin()&&k2!=d[val[x]].end()) insert(root[ndp],root[ndp],1,n,dfn[lca(k1->x,k2->x)],1);
if(k!=d[val[x]].begin()) insert(root[ndp],root[ndp],1,n,dfn[lca(k1->x,x)],-1);
if(k2!=d[val[x]].end()) insert(root[ndp],root[ndp],1,n,dfn[lca(x,k2->x)],-1);
for(int i=wuer[x].size()-1; ~i; --i) que.push(wuer[x][i]);
}
}
int main() {
// freopen("a.in","r",stdin);
int T,last;
scanf("%d",&T);
while(T--) {
scanf("%d%d",&n,&m);
tot=cnt=last=ndp=0;
for(int i=1; i<=n; ++i) {
memset(fa[i],0,sizeof fa[i]);
scanf("%d",&val[i]);
wuer[i].clear();
d[i].clear();
root[i]=0;
}
for(int i=2,x; i<=n; ++i) {
scanf("%d",&x);
wuer[x].push_back(i);
}
dfs(1,0);
bfs();
for(int x,k; m--; ) {
scanf("%d%d",&x,&k); x^=last,k^=last;
printf("%d\n",last=query(root[min(dep[x]+k,ndp)],1,n,dfn[x],dfn[x]+siz[x]-1));
}
}
return 0;
}