BZOJ 4771 主席树+倍增+set

思路:

因为有深度的限制,并且我们是在线段树上维护权值,所以我们把点按照dep排序,然后一个一个修改...主席树的下标就是dfs序,子树的查询就是区间查询...

但是发现这样怎么去维护LCA呢...因为要求有序,所以我们可以用set来维护相同颜色的节点...如果把一个点加入集合之后这个点前驱为x,后继为y,那么我们去修正,把xy的LCA+1,然后x和当前点的LCA-1,当前点和y的LCA-1...

from neighthorn

 

//By SiriusRen
#include <set> 
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=100050;
int n,m,cases,v[N],next[N],first[N],tot,fa[N][20],dfn[N],cnt,root[N],lst[N],dep[N];
int tree[N*50],lson[N*50],rson[N*50],xx,yy,ans;
struct Node{int x,deep,col;}node[N];
bool cmp(Node a,Node b){return a.deep<b.deep;}
bool cmp2(Node a,Node b){return a.x<b.x;}
struct Cmp{bool operator()(Node a,Node b){return dfn[a.x]<dfn[b.x];}};
set<Node,Cmp>s[N];set<Node,Cmp>::iterator it,it2,it3;
void add(int x,int y){v[tot]=y,next[tot]=first[x],first[x]=tot++;}
void dfs(int x){
    dfn[x]=++cnt;
    for(int i=first[x];~i;i=next[i])if(v[i]!=fa[x][0])
        dep[v[i]]=node[v[i]].deep=node[x].deep+1,dfs(v[i]);
    lst[x]=cnt;
}
void insert(int l,int r,int &pos,int last,int num,int wei){
    pos=++cnt,tree[pos]=tree[last]+wei;
    if(l==r)return;
    int mid=(l+r)>>1;
    if(mid<num)lson[pos]=lson[last],insert(mid+1,r,rson[pos],rson[last],num,wei);
    else rson[pos]=rson[last],insert(l,mid,lson[pos],lson[last],num,wei);
}
int query(int l,int r,int pos,int L,int R){
    if(l>=L&&r<=R)return tree[pos];
    int mid=(l+r)>>1;
    if(mid<L)return query(mid+1,r,rson[pos],L,R);
    else if(mid>=R)return query(l,mid,lson[pos],L,R);
    else return query(l,mid,lson[pos],L,R)+query(mid+1,r,rson[pos],L,R);
}
int lca(int x,int y){
    if(dep[x]<dep[y])swap(x,y);
    for(int i=19;~i;i--)if(dep[x]-(1<<i)>=dep[y])x=fa[x][i];
    if(x==y)return x;
    for(int i=19;~i;i--)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
int main(){
    scanf("%d",&cases);
    while(cases--){
        memset(first,-1,sizeof(first)),ans=tot=cnt=0;
        scanf("%d%d",&n,&m);
        for(int i=1;i<=n;i++)scanf("%d",&node[i].col),node[i].x=i;
        for(int i=2;i<=n;i++)scanf("%d",&fa[i][0]),add(fa[i][0],i);
        for(int j=1;j<=19;j++)for(int i=1;i<=n;i++)fa[i][j]=fa[fa[i][j-1]][j-1];
        dep[1]=node[1].deep=1,dfs(1),cnt=0,sort(node+1,node+1+n,cmp);
        for(int i=1;i<=n;i++){
            int lst=node[i-1].deep,now=node[i].deep;
            insert(1,n,root[now],root[lst],dfn[node[i].x],1);
            s[node[i].col].insert(node[i]),it=s[node[i].col].find(node[i]);it2=it,++it2;
            if(s[node[i].col].size()>2&&it!=s[node[i].col].begin()&&it2!=s[node[i].col].end())
                it2=it,it2--,it3=it,it3++,insert(1,n,root[now],root[now],dfn[lca((*it2).x,(*it3).x)],1);
            if(it!=s[node[i].col].begin())it2=it,it2--,insert(1,n,root[now],root[now],dfn[lca((*it2).x,(*it).x)],-1);
            it2=it,++it2;
            if(it2!=s[node[i].col].end())insert(1,n,root[now],root[now],dfn[lca((*it2).x,(*it).x)],-1);
        }
        while(m--){
            scanf("%d%d",&xx,&yy),xx^=ans,yy^=ans;
            printf("%d\n",ans=query(1,n,root[min(dep[xx]+yy,node[n].deep)],dfn[xx],lst[xx]));
        }
        for(int i=1;i<=n;i++)s[i].clear(),root[i]=0;
    }
}

 

posted @ 2017-03-21 08:26  SiriusRen  阅读(430)  评论(0编辑  收藏  举报