SPOJ - COT 树上路径第k小

SPOJ - COT

题意:求树上路径第k小

思路:开始想树剖到主席树上做, 但是其实不需要,我们都知道求树上2点的距离 l = deep[u] + deep[v] - 2*deep[lca(u,v)],这里的深度deep其实就是路径长度的前缀和,同理可以以此建主席树,每颗线段树的前一个版本就是它的父亲,那么查询 路径 u->v的第k小就是在rt[u] + rt[v] - rt[lca(u,v)] -rt[fa[lca(u,v)]],区间第k小是rt[L-1] 和rt[R] 2颗树同时向下并相减,那么这里就是4颗树同时向下并相加减,注意这个式子,因为rt[u] + rt[v] 计算了2次rt[lca(u,v)],所以要减去一次,减去rt[fa[lca(u,v)]]相当于区间第k小减去rt[L-1],难点主要在于理解主席树前缀的思想

AC代码:

复制代码
#include "iostream"
#include "iomanip"
#include "string.h"
#include "stack"
#include "queue"
#include "string"
#include "vector"
#include "set"
#include "map"
#include "algorithm"
#include "stdio.h"
#include "math.h"
#pragma comment(linker, "/STACK:102400000,102400000")
#define bug(x) cout<<x<<" "<<"UUUUU"<<endl;
#define mem(a,x) memset(a,x,sizeof(a))
#define step(x) fixed<< setprecision(x)<<
#define mp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define ll long long
#define endl ("\n")
#define ft first
#define sd second
#define lrt (rt<<1)
#define rrt (rt<<1|1)
using namespace std;
const ll mod=1e9+7;
const ll INF = 1e18+1LL;
const int inf = 1e9+1e8;
const double PI=acos(-1.0);
const int N=1e5+100;

int n, m, a[N], ran[N];
int head[N*2], nex[N*2], to[N*2], deep[N], p[N][30], f[N], tot;
int rt[N], ls[N*20], rs[N*20], sum[N*20], cnt;

void add(int u, int v) {
    to[tot] = v;
    nex[tot] = head[u];
    head[u] = tot++;
}

void init() {
    int i,j;
    for(j=1;(1<<j)<=n;j++)
        for(i=1;i<=n;i++)
            if(p[i][j-1]!=-1)
                p[i][j]=p[p[i][j-1]][j-1];
}
int LCA(int a,int b) {
    int i,j;
    if(deep[a]<deep[b]) swap(a,b);
    for(i=0;(1<<i)<=deep[a];i++);
    i--; //使a,b两点的深度相同
    for(j=i;j>=0;j--)
        if(deep[a]-(1<<j)>=deep[b])
            a=p[a][j];
    if(a==b)return a; //倍增法,每次向上进深度2^j,找到最近公共祖先的子结点
    for(j=i;j>=0;j--) {
        if(p[a][j]!=-1&&p[a][j]!=p[b][j]) {
            a=p[a][j], b=p[b][j];
        }
    }
    return p[a][0];
}

void updata(int &cur, int l, int r, int p, int last) {
    cur = ++cnt;
    sum[cur] = sum[last]+1;
    if(l == r) return;
    ls[cur] = ls[last];
    rs[cur] = rs[last];
    int mid = l+r>>1;
    if(p<=mid) updata(ls[cur], l, mid, p, ls[last]);
    else updata(rs[cur], mid+1, r, p, rs[last]);
}

void dfs(int u, int fa) {//cout<<a[u]<<endl;
    deep[u] = deep[fa]+1, p[u][0] = fa, f[u] = fa;
    updata(rt[u], 1, n, ran[u], rt[fa]);
    for(int i=head[u]; i!=-1; i=nex[i]) {
        int v = to[i];
        if(v == fa) continue;
        dfs(v, u);
    }
}

int query(int rt_u, int rt_v, int rt_lca, int rt_flca, int l, int r, int k) {
    if(l == r) return l;
    int t = sum[ls[rt_u]]+sum[ls[rt_v]]-sum[ls[rt_lca]]-sum[ls[rt_flca]];
    int mid = l+r>>1;
    if(k<=t) return query(ls[rt_u], ls[rt_v], ls[rt_lca], ls[rt_flca], l, mid, k);
    else return query(rs[rt_u], rs[rt_v], rs[rt_lca], rs[rt_flca], mid+1, r, k-t);
}

struct Node{
    int v, id;
    bool friend operator< (Node a, Node b) {
        return a.v<b.v;
    }
}arr[N];

int main() {
    int u, v, k;
    memset(head, -1, sizeof(head));
    scanf("%d %d", &n, &m);
    for(int i=1; i<=n; ++i) {
        scanf("%d", &a[i]);
        arr[i].id = i, arr[i].v = a[i];
    }
    for(int i=1; i<n; ++i) {
        scanf("%d %d", &u, &v);
        add(u, v), add(v, u);
    }
    sort(arr+1, arr+1+n);
    for(int i=1; i<=n; ++i) ran[arr[i].id] = i;
    dfs(1, 0);
    init();
    while(m--) {
        scanf("%d %d %d", &u, &v, &k);
        int lca = LCA(u, v);
        printf("%d\n", arr[query(rt[u], rt[v], rt[lca], rt[f[lca]], 1, n, k)].v);
    }
    return 0;
}
复制代码

 

posted on   lazzzy  阅读(182)  评论(0编辑  收藏  举报

努力加载评论中...

导航

点击右上角即可分享
微信分享提示