HDU 4757 可持久化trie树

  首先如果给定一些数,询问这些数中哪个数^给定的数的值最大的话,我们可以建立一颗trie树,根连接的两条边分别为0,1,表示二进制下第15位,那么我们可以建立一颗trie树,每一条从根到叶子节点的链表示一个2^16以内的数,开始每个节点的cnt都是0,那么每插入一个元素,将表示这个值的链上所有位置的cnt++,那么对于一个值要求使得^最大,如果这个值的某一位是1,那么我们最好要找到一个为0的数来和他^,那么判断下0儿子的cnt是不是大于0,然后做就好了。

  那么对于这棵树,我们可以先将1为根,然后对于两个点x,y之间的链拆成x,lca和y,lca的两条链,现在问题就转化为了求一个深度递增的链上所有值和给定值的^值最大,那么我们可以建立可持久化trie,每个节点继承父节点的trie树,我们只需要用x的trie树减去lca father的trie树做开始的问题就好了。

  反思:调试的时候输出调试的,然后答案更新的只按照一部分更新的,忘了改回去了。因为这个题没看题,是别人讲的题意,所以没看到多组数据,在这儿一直错= =。

//By BLADEVIL
#include <cstdio>
#include <cstring>
#include <algorithm>
#define maxn 200010

using namespace std;

struct ww {
    int son[2];
    int cnt;
    ww() {
        cnt=0;
        memset(son,0,sizeof son);
    }
}t[maxn<<5];

int n,m,l,tot;
int a[maxn],pre[maxn<<1],other[maxn<<1],last[maxn],que[maxn],dis[maxn],jump[maxn][20];

void connect(int x,int y) {
    pre[++l]=last[x];
    last[x]=l;
    other[l]=y;
}

int lca(int x,int y) {
    if (dis[x]>dis[y]) swap(x,y);
    int dep=dis[y]-dis[x];
    for (int i=0;i<=18;i++) if (dep&(1<<i)) y=jump[y][i];
    if (x==y) return x;
    for (int i=18;i>=0;i--) if (jump[x][i]!=jump[y][i]) x=jump[x][i],y=jump[y][i];
    return jump[x][0];
}

void build(int &x,int dep) {
    if (!x) x=++tot;
    if (dep<0) return ;
    build(t[x].son[0],dep-1); build(t[x].son[1],dep-1);
}    

void insert(int &x,int rot,int dep,int key) {
    if (!x) x=++tot;
    if (dep==-2) return ; 
    if (key&(1<<dep)) {
        insert(t[x].son[1],t[rot].son[1],dep-1,key);
        t[x].son[0]=t[rot].son[0];
    } else {
        insert(t[x].son[0],t[rot].son[0],dep-1,key);
        t[x].son[1]=t[rot].son[1];
    }
    t[x].cnt+=t[rot].cnt+1;
    //printf("|%d %d\n",t[x].cnt,x);
}

int query(int rx,int lx,int key,int dep) {
    if (dep==-2) return 0;
    //printf("%d %d %d %d\n",t[rx].son[1],t[rx].son[0],t[t[rx].son[1]].cnt,t[t[rx].son[0]].cnt);
    int ans=0;
    if (key&(1<<dep)) {
        if (t[t[rx].son[0]].cnt-t[t[lx].son[0]].cnt) {
            ans=1<<dep;
            ans+=query(t[rx].son[0],t[lx].son[0],key,dep-1);
        } else ans=query(t[rx].son[1],t[lx].son[1],key,dep-1);
    } else {
        if (t[t[rx].son[1]].cnt-t[t[lx].son[1]].cnt) {
            ans=1<<dep;
            ans+=query(t[rx].son[1],t[lx].son[1],key,dep-1);
        } else ans=query(t[rx].son[0],t[lx].son[0],key,dep-1);
    }
    //printf("%d\n",ans);
    return ans;
}

void work() {
     memset(t,0,sizeof t);
    memset(last,0,sizeof last);
    memset(dis,0,sizeof dis);
    tot=n; l=0;
    for (int i=1;i<=n;i++) scanf("%d",&a[i]);
    for (int i=1;i<n;i++) {
        int x,y; scanf("%d%d",&x,&y);
        connect(x,y); connect(y,x);
    }
    int head=0,tail=1; que[1]=1; dis[1]=1;
    while (head<tail) {
        int cur=que[++head];
        for (int p=last[cur];p;p=pre[p]) {
            if (dis[other[p]]) continue;
            que[++tail]=other[p]; dis[other[p]]=dis[cur]+1;
        }
    }
    //for (int i=1;i<=n;i++) printf(i==n?"%d\n":"%d ",que[i]);
    jump[1][0]=++tot;
    for (int i=1;i<=n;i++) 
        for (int p=last[que[i]];p;p=pre[p]) 
            if (dis[other[p]]>dis[que[i]]) jump[other[p]][0]=que[i];
    for (int i=1;i<=18;i++)
        for (int j=1;j<=n;j++){
            int cur=que[j];
            jump[cur][i]=jump[jump[cur][i-1]][i-1];
        }
    build(jump[1][0],15);
    for (int i=1;i<=n;i++) insert(que[i],jump[que[i]][0],15,a[que[i]]);
    //for (int i=1;i<=tot;i++) printf("%d %d %d %d\n",i,t[i].son[0],t[i].son[1],t[i].cnt);
    //int x,y; scanf("%d%d",&x,&y); printf("%d\n",lca(x,y));
    while (m--) {
        int x,y,z; scanf("%d%d%d",&x,&y,&z);
        int root=lca(x,y);
        int ans=0;
        ans=max(query(x,jump[root][0],z,15),query(y,jump[root][0],z,15));
        //ans=query(y,jump[root][0],z,15);
        printf("%d\n",ans);
    }
}

int main() {
    while (scanf("%d%d",&n,&m)!=EOF) work(); 
    return 0;
}

 

posted on 2014-04-16 18:29  BLADEVIL  阅读(1156)  评论(0编辑  收藏  举报