Count on a tree SPOJ - COT

原题链接
考察：主席树（可持久化线段树）+ LCA（倍增）
写错了LCA的板子,debug几个小时...
思路:
  与在一维数组上建主席树（以前一个位置为上一个版本）不同，树上建主席树是以父节点的版本为上一个版本，因此版本 root[x] 维护的是从根到 x 这条路径上的所有点权。求路径 (u,v) 上的第 k 小值时，某个值域区间内落在路径上的数的个数实际就是：

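\[ \mathrm{cnt}(root_u) + \mathrm{cnt}(root_v) - \mathrm{cnt}(root_{\mathrm{lca}}) - \mathrm{cnt}(root_{fa(\mathrm{lca})}) \]

其中 \(\mathrm{cnt}(root_x)\) 表示版本 root[x] 中落在当前值域区间内的数的个数。root[u] 和 root[v] 都包含根到 lca 的整段路径：减去 root[lca] 去掉重复的一份；再减去 root[fa(lca)]，去掉根到 fa(lca) 的部分。这样 lca 本身恰好只被计入一次。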

  代码写得比较繁琐：实际上求深度和倍增祖先的 bfs 与建主席树的 dfs 可以合并成一次遍历。

Code

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>
using namespace std;
// 64-bit integer type used for vertex weights.
using LL = long long;
// N: capacity of the per-vertex arrays (problem bound 1e5 plus slack).
// M: number of binary-lifting levels; 2^(M-1) = 131072 exceeds N.
const int N = 100010;
const int M = 18;
int n, m;          // vertex count, query count
int fa[N][M];      // fa[v][j]: 2^j-th ancestor of v (0 once past the root)
int idx;           // next free slot in road[]
int h[N];          // adjacency-list head per vertex (-1 = empty, set in main)
int d[N];          // NOTE(review): not referenced in the visible code
int dist[N];       // BFS depth from the root (filled by bfs)
int tot;           // number of tr[] nodes allocated so far
int root[N];       // root[v]: persistent-tree version for the path root..v
LL w[N];           // vertex weights
vector<LL> alls;   // sorted, de-duplicated weights (coordinate compression)
// One directed half of an undirected edge, stored in a head-insertion
// linked adjacency list (see add()).
struct Road
{
    int to;  // vertex this half-edge points to
    int ne;  // index of the next half-edge leaving the same vertex, or -1
} road[N << 1];  // two half-edges per tree edge
// Node of the persistent (functional) segment tree over compressed value ranks.
// Plain aggregate: the implicitly-defined copy assignment copies l, r and cnt
// and returns Node&, which is exactly what insert() needs when it clones a node.
// The former hand-written operator= returned a Node by value and added nothing.
struct Node
{
    int l, r;  // indices of the left/right children in tr[] (0 for leaves)
    int cnt;   // how many inserted values have a rank inside this node's range
} tr[N * 21];  // build() uses ~2*V nodes, each insert() about log2(V)+1 more
// Map a weight to its 1-based rank among the sorted, de-duplicated values in alls.
// For a value not present, this yields the rank of the first larger value.
int find(LL x)
{
    auto pos = lower_bound(alls.begin(), alls.end(), x);
    return static_cast<int>(pos - alls.begin()) + 1;
}
void add(int a,int b)
{
   road[idx].to = b, road[idx].ne = h[a], h[a] = idx++;
}
// Breadth-first search from r. Fills dist[] with depths (dist[r] = 0) and the
// binary-lifting table fa[][] used by lca(). dist[0] is set to -1 so that a jump
// past the root, which lands on the sentinel vertex 0, is never taken in lca().
void bfs(int r)
{
    memset(dist, 0x3f, sizeof dist);  // large "infinity": all vertices unvisited
    dist[r] = 0;
    dist[0] = -1;
    queue<int> frontier;
    frontier.push(r);
    while (!frontier.empty())
    {
        const int cur = frontier.front();
        frontier.pop();
        for (int e = h[cur]; e != -1; e = road[e].ne)
        {
            const int nxt = road[e].to;
            if (dist[nxt] <= dist[cur] + 1)
                continue;  // already reached at an equal or smaller depth
            dist[nxt] = dist[cur] + 1;
            fa[nxt][0] = cur;
            // cur's ancestors were all finalized earlier in BFS order,
            // so every lifting level of nxt can be filled right away.
            for (int j = 1; j < M; ++j)
                fa[nxt][j] = fa[fa[nxt][j - 1]][j - 1];
            frontier.push(nxt);
        }
    }
}
// Lowest common ancestor of a and b by binary lifting over fa[][] and dist[].
int lca(int a,int b)
{
    int deep = a;
    int shallow = b;
    if (dist[deep] < dist[shallow])
        swap(deep, shallow);
    // Raise the deeper vertex to the depth of the shallower one.
    for (int j = M - 1; j >= 0; --j)
    {
        const int up = fa[deep][j];
        if (dist[up] >= dist[shallow])
            deep = up;
    }
    if (deep == shallow)
        return deep;
    // Raise both while their ancestors differ; they stop just below the LCA.
    for (int j = M - 1; j >= 0; --j)
    {
        if (fa[deep][j] == fa[shallow][j])
            continue;
        deep = fa[deep][j];
        shallow = fa[shallow][j];
    }
    return fa[deep][0];
}
// Allocate an all-zero segment tree over ranks [l, r] and return its root index.
// main() stores the result as root[0], the empty version used as the root's "parent".
int build(int l,int r)
{
    const int node = ++tot;
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        tr[node].l = build(l, mid);
        tr[node].r = build(mid + 1, r);
        tr[node].cnt = 0;
    }
    return node;
}
// Persistent insert. Returns the root of a new version equal to `last` plus one
// occurrence of rank `val`. Only nodes on the root-to-leaf path of `val` are
// copied; all other subtrees are shared with `last`.
int insert(int last,int val,int l,int r)
{
    const int fresh = ++tot;
    tr[fresh] = tr[last];  // start as a clone of the previous version's node
    if (l == r)
    {
        ++tr[fresh].cnt;
        return fresh;
    }
    const int mid = (l + r) >> 1;
    if (val > mid)
        tr[fresh].r = insert(tr[last].r, val, mid + 1, r);
    else
        tr[fresh].l = insert(tr[last].l, val, l, mid);
    tr[fresh].cnt = tr[tr[fresh].l].cnt + tr[tr[fresh].r].cnt;
    return fresh;
}
// Build the persistent segment-tree version root[x] for every vertex x reached
// from u, where u's parent is fat: root[x] = root[parent(x)] plus the rank of w[x].
// Each root[x] therefore holds the weights on the path from the traversal root to x.
// The caller must already have built root[fat] (main() builds root[0]).
//
// Uses an explicit stack instead of recursion. The recursive form had a call
// depth equal to the tree height, up to n = 1e5 on a path-shaped tree, which
// risks overflowing the call stack.
void dfs(int u,int fat)
{
    const int vals = static_cast<int>(alls.size());
    vector<pair<int, int>> pending;  // (vertex, parent) pairs still to process
    pending.emplace_back(u, fat);
    while (!pending.empty())
    {
        const int x = pending.back().first;
        const int p = pending.back().second;
        pending.pop_back();
        // root[p] is final here: a vertex is pushed only after its parent's
        // version has been built.
        root[x] = insert(root[p], find(w[x]), 1, vals);
        for (int i = h[x]; i != -1; i = road[i].ne)
        {
            const int v = road[i].to;
            if (v != p)
                pending.emplace_back(v, x);
        }
    }
}
// k-th smallest weight on the path u..v. main() passes the four versions
//   a = root[u], b = root[v], anc = root[lca(u,v)], ancfa = root[fa[lca][0]].
// The number of path values in a rank range is cnt(a) + cnt(b) - cnt(anc) - cnt(ancfa).
// Walks from [l, r] down to the leaf holding the answer. k is not validated here.
LL ask(int a,int b,int anc,int ancfa,int k,int l,int r)
{
    while (l != r)
    {
        const int inLeft = tr[tr[a].l].cnt + tr[tr[b].l].cnt
                         - tr[tr[anc].l].cnt - tr[tr[ancfa].l].cnt;
        const int mid = (l + r) >> 1;
        if (k > inLeft)
        {
            // The answer lies in the right half; skip the values on the left.
            k -= inLeft;
            a = tr[a].r;
            b = tr[b].r;
            anc = tr[anc].r;
            ancfa = tr[ancfa].r;
            l = mid + 1;
        }
        else
        {
            a = tr[a].l;
            b = tr[b].l;
            anc = tr[anc].l;
            ancfa = tr[ancfa].l;
            r = mid;
        }
    }
    return alls[l - 1];  // ranks are 1-based, alls is 0-based
}
int main()
{
   scanf("%d%d", &n, &m);
   memset(h, -1, sizeof h);
   for (int i = 1; i <= n; i++) scanf("%lld", &w[i]), alls.push_back(w[i]);
   sort(alls.begin(), alls.end());
   alls.erase(unique(alls.begin(), alls.end()), alls.end());
   for (int i = 1; i < n;i++)
   {
       int a, b;
       scanf("%d%d", &a, &b);
       add(a, b), add(b, a);
   }
   bfs(1);
   root[0] = build(1, alls.size());
   dfs(1,0);
   while (m--)
   {
       int u, v, k;
       scanf("%d%d%d", &u, &v, &k);
       int anc = lca(u, v);
       printf("%lld\n", ask(root[u], root[v], root[anc], root[fa[anc][0]], k, 1, alls.size()));
   }
   return 0;
}

posted @ 2021-07-02 13:07  acmloser  阅读(27)  评论(0)  编辑  收藏  举报