树上主席树 - 查询树链上第K大

Description

给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。

Input

第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。

Output

M行,表示每个询问的答案。最后一个询问不输出换行符

Sample Input

8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2

Sample Output

2
8
9
105
7

HINT

N,M<=100000
暴力自重。。。
 
题意 : 给你一棵树,每次询问任意两点间第 K 小的元素,强调在线。
思路分析 :还是利用主席树,在树上建立主席树,从根节点到当前节点建立线段树,借助 lca, 然后比如要查询 u, v 这段区间,要先找到u, v 之间的公共祖先 lca
代码示例 :
#define ll long long
const int maxn = 1e5+5;
const int mod = 1e9+7;
const double eps = 1e-9;
const double pi = acos(-1.0);
const int inf = 0x3f3f3f3f;

int n, m;
int pre[maxn], rank[maxn];
vector<int>ve[maxn];
int dep[maxn];
int grand[maxn][22];
int root[maxn];
int cnt, ss;
struct node
{
    int l, r;
    int sum;
}t[maxn*20];

void update(int num, int &rt, int l, int r){
    t[cnt++] = t[rt];
    rt = cnt-1;
    t[rt].sum++;
    
    if (l == r) return;
    int m = (l+r)>>1;
    if (num <= m) update(num, t[rt].l, l, m);
    else update(num, t[rt].r, m+1, r);
}

void dfs(int x, int fa){
    for(int i = 1; i <= 20; i++){
        grand[x][i] = grand[grand[x][i-1]][i-1];
    }
    int num = lower_bound(rank+1, rank+ss, pre[x])-rank;
    root[x] = root[fa];
    update(num, root[x], 1, n);
    for(int i = 0; i < ve[x].size(); i++){
        int to = ve[x][i];
        if (to == fa) continue;
        
        dep[to] = dep[x]+1;
        grand[to][0] = x;
        dfs(to, x);
    }
}

void init(){
    cnt = 1;
    root[0] = 0;
    t[0].l = t[0].r = t[0].sum = 0;
}

int getlca(int u, int v){
    if (dep[u] < dep[v]) swap(u, v); // u 是在下面的
    
    for(int i = 20; i >= 0; i--){
        if (dep[grand[u][i]] >= dep[v]) u = grand[u][i];
    }
    if (u == v) return u;
    
    for(int i = 20; i >= 0; i--){
        if (grand[u][i] != grand[v][i]){
            u = grand[u][i];
            v = grand[v][i];
        }
    }
    return grand[u][0];
}

int query(int t1, int t2, int t3, int t4, int k, int l, int r){
    int d = t[t[t1].l].sum+t[t[t2].l].sum-t[t[t3].l].sum-t[t[t4].l].sum;
    //printf("l = %d r = %d d = %d\n", l, r, d);
    if (l == r) return l;
    int m = (l+r)>>1;
    if (k <= d) return query(t[t1].l, t[t2].l, t[t3].l, t[t4].l, k, l, m);
    else return query(t[t1].r, t[t2].r, t[t3].r, t[t4].r, k-d, m+1, r);
}
int last=0;

int main() {
    //freopen("in.txt", "r", stdin);
    //freopen("out.txt", "w", stdout);
    cin >> n >> m;
    for(int i = 1; i <= n; i++){
        scanf("%d", &pre[i]);
        rank[i] = pre[i];
    }
    sort(rank+1, rank+1+n);
    ss = unique(rank+1, rank+1+n)-rank;
    int a, b;
    for(int i = 1; i < n; i++){
        scanf("%d%d", &a, &b);
        ve[a].push_back(b), ve[b].push_back(a);
    }
    dep[1] = 1;
    init();
    dfs(1, 0);
    
    int u, v, k;    
    for(int i = 1; i <= m; i++){
        scanf("%d%d%d", &u, &v, &k);
        u ^= last;
        //printf("*** u = %d v = %d\n", u, v);
        int lca = getlca(u, v);
        int ans = query(root[u], root[v], root[lca], root[grand[lca][0]], k, 1, n); 
        last = rank[ans];
        //printf("*** ans = %d\n", ans);
        printf("%d\n", rank[ans]);
    }
    return 0;
}

 

posted @ 2018-04-14 19:01  楼主好菜啊  阅读(641)  评论(0编辑  收藏  举报