SPOJ - COT 树上路径第k小
题意:求树上路径第k小
思路:开始想树剖到主席树上做, 但是其实不需要,我们都知道求树上2点的距离 l = deep[u] + deep[v] - 2*deep[lca(u,v)],这里的深度deep其实就是路径长度的前缀和,同理可以以此建主席树,每颗线段树的前一个版本就是它的父亲,那么查询 路径 u->v的第k小就是在rt[u] + rt[v] - rt[lca(u,v)] -rt[fa[lca(u,v)]],区间第k小是rt[L-1] 和rt[R] 2颗树同时向下并相减,那么这里就是4颗树同时向下并相加减,注意这个式子,因为rt[u] + rt[v] 计算了2次rt[lca(u,v)],所以要减去一次,减去rt[fa[lca(u,v)]]相当于区间第k小减去rt[L-1],难点主要在于理解主席树前缀的思想
AC代码:
#include "iostream" #include "iomanip" #include "string.h" #include "stack" #include "queue" #include "string" #include "vector" #include "set" #include "map" #include "algorithm" #include "stdio.h" #include "math.h" #pragma comment(linker, "/STACK:102400000,102400000") #define bug(x) cout<<x<<" "<<"UUUUU"<<endl; #define mem(a,x) memset(a,x,sizeof(a)) #define step(x) fixed<< setprecision(x)<< #define mp(x,y) make_pair(x,y) #define pb(x) push_back(x) #define ll long long #define endl ("\n") #define ft first #define sd second #define lrt (rt<<1) #define rrt (rt<<1|1) using namespace std; const ll mod=1e9+7; const ll INF = 1e18+1LL; const int inf = 1e9+1e8; const double PI=acos(-1.0); const int N=1e5+100; int n, m, a[N], ran[N]; int head[N*2], nex[N*2], to[N*2], deep[N], p[N][30], f[N], tot; int rt[N], ls[N*20], rs[N*20], sum[N*20], cnt; void add(int u, int v) { to[tot] = v; nex[tot] = head[u]; head[u] = tot++; } void init() { int i,j; for(j=1;(1<<j)<=n;j++) for(i=1;i<=n;i++) if(p[i][j-1]!=-1) p[i][j]=p[p[i][j-1]][j-1]; } int LCA(int a,int b) { int i,j; if(deep[a]<deep[b]) swap(a,b); for(i=0;(1<<i)<=deep[a];i++); i--; //使a,b两点的深度相同 for(j=i;j>=0;j--) if(deep[a]-(1<<j)>=deep[b]) a=p[a][j]; if(a==b)return a; //倍增法,每次向上进深度2^j,找到最近公共祖先的子结点 for(j=i;j>=0;j--) { if(p[a][j]!=-1&&p[a][j]!=p[b][j]) { a=p[a][j], b=p[b][j]; } } return p[a][0]; } void updata(int &cur, int l, int r, int p, int last) { cur = ++cnt; sum[cur] = sum[last]+1; if(l == r) return; ls[cur] = ls[last]; rs[cur] = rs[last]; int mid = l+r>>1; if(p<=mid) updata(ls[cur], l, mid, p, ls[last]); else updata(rs[cur], mid+1, r, p, rs[last]); } void dfs(int u, int fa) {//cout<<a[u]<<endl; deep[u] = deep[fa]+1, p[u][0] = fa, f[u] = fa; updata(rt[u], 1, n, ran[u], rt[fa]); for(int i=head[u]; i!=-1; i=nex[i]) { int v = to[i]; if(v == fa) continue; dfs(v, u); } } int query(int rt_u, int rt_v, int rt_lca, int rt_flca, int l, int r, int k) { if(l == r) return l; int t = sum[ls[rt_u]]+sum[ls[rt_v]]-sum[ls[rt_lca]]-sum[ls[rt_flca]]; int mid = l+r>>1; if(k<=t) return query(ls[rt_u], ls[rt_v], ls[rt_lca], ls[rt_flca], l, mid, k); else return query(rs[rt_u], rs[rt_v], rs[rt_lca], rs[rt_flca], mid+1, r, k-t); } struct Node{ int v, id; bool friend operator< (Node a, Node b) { return a.v<b.v; } }arr[N]; int main() { int u, v, k; memset(head, -1, sizeof(head)); scanf("%d %d", &n, &m); for(int i=1; i<=n; ++i) { scanf("%d", &a[i]); arr[i].id = i, arr[i].v = a[i]; } for(int i=1; i<n; ++i) { scanf("%d %d", &u, &v); add(u, v), add(v, u); } sort(arr+1, arr+1+n); for(int i=1; i<=n; ++i) ran[arr[i].id] = i; dfs(1, 0); init(); while(m--) { scanf("%d %d %d", &u, &v, &k); int lca = LCA(u, v); printf("%d\n", arr[query(rt[u], rt[v], rt[lca], rt[f[lca]], 1, n, k)].v); } return 0; }