BZOJ4771 七彩树

貌似是道经典题(?强行增加代码量就是了= =

考虑维护树链的并,用线段树维护差分就可以了

差分:

关键点+1

两个相邻:LCA-1

再考虑深度限制,直接上个主席树就好了

因为主席树要排序后插入 所以会对上面的差分带来后效性

如果两边都有点的话 LCA(左,右)+1

突然发现PKUWC的D1T2其实就是这个东西...

考虑是树上联通块 大概就是对于边计算这个,对于点也计算一个 然后就可以FFT以后容斥...

行吧当时的水平写出虚树+链剖我很知足了= =

//Love and Freedom.
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<set>
#include<cassert>
#define ll long long
#define inf 20021225
#define N 100100
#define pb push_back
#define IT set<int>::iterator
using namespace std;
int read()
{
    int s=0,f=1; char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-') f=-1; ch=getchar();}
    while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
    return f*s;
}
struct edge{int to,lt;}e[N];
int in[N],cnt,f[N][18],dep[N],dfn[N],tms,idfn[N],sz[N];
void add(int x,int y)
{
    e[++cnt].to=y; e[cnt].lt=in[x]; in[x]=cnt;
}
void dfs(int x)
{
    dfn[x]=++tms; idfn[tms]=x; sz[x]=1;
    for(int i=1;i<18;i++)    f[x][i]=f[f[x][i-1]][i-1];
    for(int i=in[x];i;i=e[i].lt)
    {
        int y=e[i].to; dep[y]=dep[x]+1;
        f[y][0]=x; dfs(y); sz[x]+=sz[y];
    }
}
int LCA(int x,int y)
{
    if(dep[x]<dep[y])    swap(x,y);
    int len=dep[x]-dep[y];
    for(int i=0;i<18;i++)    if(len>>i&1)
        x=f[x][i];
    if(x==y)    return x;
    for(int i=17;~i;i--)    if(f[x][i]!=f[y][i])
        x=f[x][i],y=f[y][i];
    //assert(x>1&&y>1);
    return f[x][0];
}
struct node{int sum,ls,rs;}t[N*80];
int poi,n,c[N],id[N];
set<int> cc[N]; int rt[N];
void insert(int x,int &y,int l,int r,int p,int v)
{
    y=++poi; t[y]=t[x]; t[y].sum+=v;
    if(l==r)    return; int mid=l+r>>1;
    if(p<=mid)    insert(t[x].ls,t[y].ls,l,mid,p,v);
    else    insert(t[x].rs,t[y].rs,mid+1,r,p,v);
}
int query(int x,int l,int r,int LL,int RR)
{
    if(!x || (LL<=l&&RR>=r))    return t[x].sum;
    int mid=l+r>>1,ans=0;
    if(LL<=mid)    ans+=query(t[x].ls,l,mid,LL,RR);
    if(RR>mid)    ans+=query(t[x].rs,mid+1,r,LL,RR);
    return ans;
}
void clear()
{
    memset(rt,0,n+1<<2);
    memset(in,0,n+1<<2);
    cnt=tms=poi=0;
    for(int i=1;i<=n;i++)
        cc[i].clear();
}
bool azy(int a,int b){return dep[a]<dep[b];}
void solve()
{
    clear(); int lastans=0;
    n=read(); int q=read(),ff,mxd=0;
    for(int i=1;i<=n;i++)
        c[i]=read(),id[i]=i;
    for(int i=2;i<=n;i++)    ff=read(),add(ff,i);
    dep[1]=1; dfs(1); sort(id+1,id+n+1,azy);
    for(int i=1;i<=n;i++)
    {
        int x=id[i],lst=0,nxt=0,cl=c[x],dd=dep[x],gg=dep[id[i-1]];
        mxd=max(mxd,dep[x]);// assert(x);
        IT ls=cc[cl].lower_bound(dfn[x]),
           nx=cc[cl].upper_bound(dfn[x]);
        insert(rt[gg],rt[dd],1,n,dfn[x],1);
        if(ls!=cc[cl].begin())
        {
            lst=idfn[*(--ls)];
            insert(rt[dd],rt[dd],1,n,dfn[LCA(lst,x)],-1);
        }
        if(nx!=cc[cl].end())
        {
            nxt=idfn[*nx];
            insert(rt[dd],rt[dd],1,n,dfn[LCA(nxt,x)],-1);
        }
        if(lst&&nxt)
            insert(rt[dd],rt[dd],1,n,dfn[LCA(nxt,lst)],1);
        cc[cl].insert(dfn[x]);
    }
    while(q--)
    {
        int x=read()^lastans,d=read()^lastans;
        //assert(x<=n);
        printf("%d\n",lastans=query(rt[min(dep[x]+d,mxd)],1,n,dfn[x],dfn[x]+sz[x]-1));
    }
}
int main()
{
    int T=read();
    while(T--)
        solve();
    return 0;
}
View Code

 

posted @ 2019-10-10 16:11  寒雨微凝  阅读(274)  评论(0编辑  收藏  举报