Count on a tree（树上路径第 K 小）
题意：给定一棵树，每个节点有一个权值；对每个询问 (u, v, k)，求从 u 到 v 的路径上（含两个端点）第 k 小的节点权值。
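题解：先 DFS 遍历建树；对每个节点，在其父节点版本的基础上插入该节点的权值，得到一棵主席树（可持久化线段树），于是节点 u 对应的版本统计了从根到 u 这条链上的全部权值。查询 (u, v, k) 时，先求 z = lca(u, v)，p = anc[z][0]（即 z 的父节点），则路径上落在某个值域区间内的权值个数，等于 u、v 两棵树上的计数减去 z、p 两棵树上的计数，即 cnt(u) + cnt(v) − cnt(z) − cnt(p)。据此比较左子树的计数与 k，在线段树上二分即可。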
另外，LCA 平时不怎么写，这次写错了，找了好久的 bug。汗。
代码:
// "Count on a tree": for each query (u, v, k) print the k-th smallest node
// weight on the tree path from u to v (both endpoints included).
//
// Approach: every tree node x owns a version root[x] of a persistent
// ("chairman") segment tree indexed by weight rank. Version x is its parent's
// version plus one insertion of w[x], so it counts the weights on the
// root -> x chain. With z = lca(u, v) and p = parent(z), the weights on the
// u-v path are counted, in every segment-tree node, by
//     cnt(u) + cnt(v) - cnt(z) - cnt(p),
// and Query descends on that combined count to find the k-th smallest.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

const int N = 100000 + 100;   // capacity for tree nodes (indices 1..n)
const int M = 5000000 + 100;  // segment-tree node pool: Build uses < 2t nodes and
                              // each insertion adds about log2(t) + 1 more
const int LOG = 20;           // binary-lifting levels; 2^(LOG-1) exceeds N

std::vector<int> son[N];           // undirected adjacency lists
int lson[M], rson[M], cnt[M];      // node 0 is never allocated: it is a null
                                   // sentinel with cnt 0 and both children 0
int anc[N][LOG];                   // anc[x][j]: 2^j-th ancestor of x, 0 past the root
int tot, t;                        // tot: nodes allocated; t: number of distinct weights
int a[N], w[N], deep[N], root[N];  // a[1..t]: sorted distinct weights; w: node weights;
                                   // deep[0] stays 0, below every real depth (>= 1)

// 1-based rank of weight x among the distinct weights a[1..t].
int id(int x){
    return std::lower_bound(a + 1, a + 1 + t, x) - a;
}

// Build an all-zero segment tree over ranks [l, r]; returns its root node.
int Build(int l, int r){
    int now = ++tot;
    cnt[now] = 0;
    if(l < r){
        int m = (l + r) >> 1;
        lson[now] = Build(l, m);
        rson[now] = Build(m + 1, r);
    }
    return now;
}

// Return the root of a new version of the tree rooted at `pre` (covering
// [l, r]) in which the count of rank c is increased by v. Only the nodes on
// the path to c are copied; every other subtree is shared with `pre`.
int Update(int l, int r, int pre, int c, int v){
    int now = ++tot;
    cnt[now] = cnt[pre] + v;
    if(l < r){
        int m = (l + r) >> 1;
        if(c <= m){
            rson[now] = rson[pre];
            lson[now] = Update(l, m, lson[pre], c, v);
        }
        else{
            lson[now] = lson[pre];
            rson[now] = Update(m + 1, r, rson[pre], c, v);
        }
    }
    return now;
}

// Fill deep[], root[] and anc[][] for the subtree rooted at u, whose parent
// is o (0 for the tree root). A node is expanded only after its parent, so
// root[parent] and the parent's ancestor row are always ready. The traversal
// uses an explicit stack instead of recursion, so a path-shaped tree of
// ~1e5 nodes cannot overflow the call stack.
void dfs(int o, int u){
    std::vector<std::pair<int, int> > stk;   // pending (parent, node) pairs
    stk.push_back(std::make_pair(o, u));
    while(!stk.empty()){
        int par = stk.back().first;
        int cur = stk.back().second;
        stk.pop_back();
        deep[cur] = deep[par] + 1;
        root[cur] = Update(1, t, root[par], id(w[cur]), 1);
        for(std::size_t i = 0; i < son[cur].size(); i++){
            int nxt = son[cur][i];
            if(nxt == par) continue;
            anc[nxt][0] = cur;
            // Fill every level that lca() reads, not just a prefix of them.
            for(int j = 1; j < LOG; j++) anc[nxt][j] = anc[anc[nxt][j - 1]][j - 1];
            stk.push_back(std::make_pair(cur, nxt));
        }
    }
}

// Lowest common ancestor via binary lifting. Jumps past the root land on
// node 0, whose depth 0 is below every real depth, so they are rejected.
int lca(int u, int v){
    if(deep[u] < deep[v]) std::swap(u, v);
    for(int i = LOG - 1; i >= 0; i--)
        if(deep[anc[u][i]] >= deep[v]) u = anc[u][i];
    if(u == v) return v;
    for(int i = LOG - 1; i >= 0; i--)
        if(anc[u][i] != anc[v][i]){
            u = anc[u][i];
            v = anc[v][i];
        }
    return anc[u][0];
}

// k-th smallest weight in the multiset whose counts are c1 + c2 - d1 - d2,
// where the four arguments are corresponding nodes (covering [l, r]) of four
// tree versions. Assumes 1 <= k <= the multiset's size; this is not checked.
int Query(int l, int r, int c1, int c2, int d1, int d2, int k){
    if(l == r) return a[l];
    int m = (l + r) >> 1;
    int num = cnt[lson[c1]] + cnt[lson[c2]] - cnt[lson[d1]] - cnt[lson[d2]];
    if(num >= k) return Query(l, m, lson[c1], lson[c2], lson[d1], lson[d2], k);
    return Query(m + 1, r, rson[c1], rson[c2], rson[d1], rson[d2], k - num);
}

int main(){
    int n, q;
    if(std::scanf("%d%d", &n, &q) != 2) return 0;
    for(int i = 1; i <= n; i++){
        if(std::scanf("%d", &w[i]) != 1) return 0;
        a[i] = w[i];
    }
    for(int i = 1; i < n; i++){
        int u, v;
        if(std::scanf("%d%d", &u, &v) != 2) return 0;
        son[u].push_back(v);
        son[v].push_back(u);
    }
    // Coordinate compression: sort, then actually drop duplicates so that t
    // is the number of distinct weights (the old copy loop left t == n).
    std::sort(a + 1, a + 1 + n);
    t = static_cast<int>(std::unique(a + 1, a + 1 + n) - (a + 1));
    root[0] = Build(1, t);
    dfs(0, 1);
    int u, v, k;
    while(q-- > 0 && std::scanf("%d%d%d", &u, &v, &k) == 3){
        int z = lca(u, v);
        int p = anc[z][0];
        std::printf("%d\n", Query(1, t, root[u], root[v], root[z], root[p], k));
    }
    return 0;
}