Count on a tree SPOJ - COT
原题链接
考察:主席树
LCA 的模板写错了，调试了好几个小时……
思路:
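与在一维数组上建立主席树不同，树上建主席树是以父节点的版本作为上一个版本，因此结点 $u$ 的版本记录了从根到 $u$ 的路径上的所有权值。求路径 $(u,v)$ 上的第 $k$ 小值时，某个值域区间内落在该路径上的权值个数实际就是：
$$
\mathrm{tr}[u].\mathrm{cnt}+\mathrm{tr}[v].\mathrm{cnt}-\mathrm{tr}[\mathrm{lca}].\mathrm{cnt}-\mathrm{tr}[\mathrm{fa}[\mathrm{lca}]].\mathrm{cnt}
$$
其中 $\mathrm{lca}$ 为 $u,v$ 的最近公共祖先，$\mathrm{fa}[\mathrm{lca}]$ 为它的父结点（根的父结点取空版本 0）。据此在四个版本上同步向下二分即可。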
代码写得比较繁琐：实际上 BFS 和 DFS 只需保留一个——两者都保证父结点先于子结点被处理，可以在同一次遍历中同时求倍增数组并建立主席树版本。
Code
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>
using namespace std;
typedef long long LL;
// N: vertex capacity (with slack).
// M: binary-lifting levels; jumps of 2^0 .. 2^(M-1) sum to 2^M - 1 >= N,
//    which covers any depth difference in a tree of up to N vertices.
const int N = 100010, M = 18;
int n, m;          // number of vertices / number of queries
int fa[N][M];      // fa[v][j]: 2^j-th ancestor of v, 0 once past the root
int idx, h[N];     // next free edge slot; head edge of each adjacency list (-1 = none)
int dist[N];       // depth of each vertex from the root; dist[0] = -1 is a sentinel
int tot, root[N];  // last used segment-tree node; root[v] = persistent version of v
LL w[N];           // vertex weights as read from input
vector<LL> alls;   // sorted, de-duplicated weights (coordinate compression)
// One directed edge of the adjacency lists; each list is a singly linked
// chain starting at h[vertex]. Capacity is two directed edges per tree edge.
struct Road
{
    int to;   // endpoint of this edge
    int ne;   // index of the next edge leaving the same vertex (-1 ends the chain)
} road[N << 1];
// Node of the persistent segment tree over compressed value ranks.
// A plain aggregate: the implicitly generated copy assignment copies all
// three fields (and returns a reference), so no user-defined operator= is needed.
struct Node
{
    int l, r;   // indices of the left/right children in tr (0 when absent)
    int cnt;    // how many inserted values have a rank inside this node's range
} tr[N*21];     // room for the initial build plus one root-to-leaf path per insert
// Maps weight x to its 1-based rank in alls (assumes alls is sorted and
// de-duplicated, as done in main); rank r corresponds back to alls[r - 1].
int find(LL x)
{
    vector<LL>::iterator pos = lower_bound(alls.begin(), alls.end(), x);
    const int zeroBased = pos - alls.begin();
    return zeroBased + 1;
}
void add(int a,int b)
{
road[idx].to = b, road[idx].ne = h[a], h[a] = idx++;
}
// Breadth-first traversal from root r. Fills dist[] with depths and the
// binary-lifting table fa[v][j] (the 2^j-th ancestor of v, 0 above the root).
void bfs(int r)
{
    memset(dist, 0x3f, sizeof dist);   // "infinity": not reached yet
    dist[r] = 0;
    dist[0] = -1;                      // sentinel: lca() never lifts past the root
    queue<int> pending;
    pending.push(r);
    while (!pending.empty())
    {
        const int cur = pending.front();
        pending.pop();
        for (int e = h[cur]; e != -1; e = road[e].ne)
        {
            const int nxt = road[e].to;
            if (dist[nxt] <= dist[cur] + 1)
                continue;              // already discovered
            dist[nxt] = dist[cur] + 1;
            fa[nxt][0] = cur;
            // A parent is dequeued before its children, so its row is complete.
            for (int j = 1; j < M; ++j)
                fa[nxt][j] = fa[fa[nxt][j - 1]][j - 1];
            pending.push(nxt);
        }
    }
}
// Lowest common ancestor of a and b, using the dist/fa tables built by bfs().
int lca(int a,int b)
{
    if (dist[a] < dist[b])
        swap(a, b);
    // Lift the deeper vertex up to b's depth; dist[0] == -1 blocks overshooting.
    for (int j = M - 1; j >= 0; --j)
    {
        const int up = fa[a][j];
        if (dist[up] >= dist[b])
            a = up;
    }
    if (a == b)
        return a;
    // Lift both while their ancestors differ; they stop just below the LCA.
    for (int j = M - 1; j >= 0; --j)
    {
        if (fa[a][j] == fa[b][j])
            continue;
        a = fa[a][j];
        b = fa[b][j];
    }
    return fa[a][0];
}
// Allocates the empty segment tree over ranks [l, r] and returns its root index.
// Leaves keep their zero-initialized fields.
int build(int l,int r)
{
    const int node = ++tot;
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        tr[node].l = build(l, mid);
        tr[node].r = build(mid + 1, r);
        tr[node].cnt = 0;
    }
    return node;
}
// Persistent point update: returns the root of a new version equal to version
// `last` plus one occurrence of rank `val` within [l, r]. Only the nodes on the
// root-to-leaf path are copied; everything else is shared with `last`.
int insert(int last,int val,int l,int r)
{
    const int node = ++tot;
    tr[node] = tr[last];
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        if (val > mid)
            tr[node].r = insert(tr[last].r, val, mid + 1, r);
        else
            tr[node].l = insert(tr[last].l, val, l, mid);
        tr[node].cnt = tr[tr[node].l].cnt + tr[tr[node].r].cnt;
    }
    else
    {
        ++tr[node].cnt;
    }
    return node;
}
// Builds the persistent version root[x] for every vertex x in the subtree of u.
// Each version is its parent's version plus the vertex's own weight, so root[x]
// holds the weights on the path from the tree root to x.
// `fat` is u's parent, and root[fat] must already exist (root[0] is the empty tree).
// Uses an explicit stack: recursion depth would equal the tree height, which
// reaches n on a path-shaped tree and can overflow the call stack.
void dfs(int u,int fat)
{
    vector<pair<int, int> > stk;   // (vertex, parent) pairs still to process
    stk.push_back(make_pair(u, fat));
    while (!stk.empty())
    {
        const int cur = stk.back().first;
        const int par = stk.back().second;
        stk.pop_back();
        // root[par] is built before any of par's children are pushed.
        root[cur] = insert(root[par], find(w[cur]), 1, alls.size());
        for (int i = h[cur]; ~i; i = road[i].ne)
        {
            const int v = road[i].to;
            if (v != par)
                stk.push_back(make_pair(v, cur));
        }
    }
}
// Returns the k-th smallest weight on the path u..v, where a/b/anc/ancfa are the
// versions of u, v, their LCA and the LCA's parent. Descends all four in lockstep:
// cnt(a) + cnt(b) - cnt(anc) - cnt(ancfa) is the number of path weights in a range.
LL ask(int a,int b,int anc,int ancfa,int k,int l,int r)
{
    while (l != r)
    {
        const int inLeft = tr[tr[a].l].cnt + tr[tr[b].l].cnt
                         - tr[tr[anc].l].cnt - tr[tr[ancfa].l].cnt;
        const int mid = (l + r) >> 1;
        if (inLeft >= k)
        {
            a = tr[a].l;
            b = tr[b].l;
            anc = tr[anc].l;
            ancfa = tr[ancfa].l;
            r = mid;
        }
        else
        {
            a = tr[a].r;
            b = tr[b].r;
            anc = tr[anc].r;
            ancfa = tr[ancfa].r;
            k -= inLeft;
            l = mid + 1;
        }
    }
    return alls[l - 1];   // 1-based rank back to the original weight
}
int main()
{
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
for (int i = 1; i <= n; i++) scanf("%lld", &w[i]), alls.push_back(w[i]);
sort(alls.begin(), alls.end());
alls.erase(unique(alls.begin(), alls.end()), alls.end());
for (int i = 1; i < n;i++)
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
bfs(1);
root[0] = build(1, alls.size());
dfs(1,0);
while (m--)
{
int u, v, k;
scanf("%d%d%d", &u, &v, &k);
int anc = lca(u, v);
printf("%lld\n", ask(root[u], root[v], root[anc], root[fa[anc][0]], k, 1, alls.size()));
}
return 0;
}