【洛谷 P2633】 Count on a tree(主席树,树上差分)
题目链接:https://www.luogu.com.cn/problem/P2633
思维难度0
实现难度7
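先将点权离散化,对每个节点 \(u\),在其父节点的主席树上插入 \(u\) 的权值,得到根到 \(u\) 路径上的权值分布。询问 \((u,v)\) 时,用两点的状态减去 \(\mathrm{lca}\) 和 \(\mathrm{lca}\) 父亲的状态(\(u\)、\(v\) 两条根路径各算一次 \(\mathrm{lca}\) 及其祖先,需各减一次,而 \(\mathrm{lca}\) 本身要保留),即得到 \(u\) 到 \(v\) 路径上的权值分布,然后在这棵差分出的新树上二分求第 \(k\) 小。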
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Capacity of the per-vertex / per-rank arrays (head, w, dep, root, f, p, val)
// and the base for the node-pool and edge-array sizes; presumably chosen to
// exceed the problem's maximum n — TODO confirm against the constraints.
const int MAXN = 100010;
// NOTE(review): MAXM is not referenced anywhere in the visible code.
const int MAXM = 100010;
// Reads one signed decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' anywhere in
// that skipped prefix makes the result negative. Returns 0 if end of input is
// reached before any digit appears (instead of looping forever).
inline int read(){
    // Wider accumulator: the int version overflowed (UB) when reading values
    // near INT_MAX, and could never represent INT_MIN's magnitude.
    long long s = 0;
    int w = 1;
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    while(ch < '0' || ch > '9'){
        if(ch == EOF) return 0;
        if(ch == '-') w = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return static_cast<int>(s * w);
}
// Number of segment-tree nodes allocated so far. Ids are pre-incremented
// (++cnt), so node 0 is never allocated and t[0] stays all-zero: it serves
// as the empty child of leaves.
int cnt;
// One node of the persistent (chairman) segment tree over weight ranks.
struct Pt{
// lc/rc: child node ids; val: how many inserted ranks fall in this node's range.
int lc, rc, val;
// Node pool; presumably sized for build (~2*tot nodes) plus one copied
// root-to-leaf path per update — TODO confirm against the maximum n.
}t[MAXN * 21];
// Allocates a complete segment tree over the rank range [l, r] and returns its
// root id. Counts remain zero because t is a zero-initialized global and every
// node here is freshly allocated.
int build(int l, int r){
    const int node = ++cnt;
    if(l != r){
        const int half = (l + r) >> 1;
        t[node].lc = build(l, half);
        t[node].rc = build(half + 1, r);
    }
    return node;
}
// Returns the root of a new version of tree p (covering [l, r]) with one more
// occurrence of rank x. Only the root-to-leaf path is copied; every other
// subtree is shared with version p.
int update(int p, int l, int r, int x){
    const int fresh = ++cnt;
    t[fresh] = t[p];  // start as a copy, then replace the child on x's side
    if(l == r){
        ++t[fresh].val;
        return fresh;
    }
    const int half = (l + r) >> 1;
    if(x > half) t[fresh].rc = update(t[p].rc, half + 1, r, x);
    else         t[fresh].lc = update(t[p].lc, l, half, x);
    t[fresh].val = t[t[fresh].lc].val + t[t[fresh].rc].val;
    return fresh;
}
// Adjacency lists as per-vertex singly linked lists of arcs:
// head[u] is the most recently added arc out of u; e[i].next is the arc added before it.
struct Edge{
    int next, to;
}e[MAXN << 1];
int head[MAXN];  // newest arc of each vertex (0 = none)
int num;         // number of arcs stored in e (arc ids start at 1)
int w[MAXN];     // discretized weight rank of each vertex
int dep[MAXN];   // depth of each vertex (vertex 1 gets depth 1)
int root[MAXN];  // persistent segment tree root for each vertex's root path

// Prepends the directed arc u -> v to u's list.
static inline void pushArc(int u, int v){
    ++num;
    e[num].to = v;
    e[num].next = head[u];
    head[u] = num;
}

// Stores the undirected edge {from, to} as two opposite arcs.
inline void Add(int from, int to){
    pushArc(from, to);
    pushArc(to, from);
}
// n: vertex count; m: query count; tot: number of distinct weights (ranks 1..tot);
// f[u][j]: the 2^j-th ancestor of u (0 above the root).
int n, m, tot, f[MAXN][20];
// A weight paired with its vertex id; sorted by weight to discretize the weights.
struct lsh{
    int val, id;
    // Orders by weight only. Returns bool and takes a const reference, as the
    // Compare requirement of std::sort expects (the original returned int and
    // copied its argument).
    bool operator < (const lsh &A) const{
        return val < A.val;
    }
}p[MAXN];
void dfs(int u, int fa){
root[u] = update(root[fa], 1, tot, w[u]);
f[u][0] = fa; dep[u] = dep[fa] + 1;
for(int i = head[u]; i; i = e[i].next)
if(e[i].to != fa)
dfs(e[i].to, u);
}
// Globals used by main's query loop: a and b are the path endpoints, c is k,
// val[r] maps weight rank r back to the original weight, ans is the last
// printed answer (XORed into the next query's first endpoint), and lca is the
// endpoints' lowest common ancestor.
int a, b, c, val[MAXN], ans, lca;
// Lowest common ancestor of u and v by binary lifting over f[][] and dep[].
int LCA(int u, int v){
    if(dep[v] < dep[u]) swap(u, v);  // make v the deeper endpoint
    // Lift v up to u's depth, one set bit of the depth difference at a time.
    int gap = dep[v] - dep[u];
    for(int k = 0; gap && k < 20; ++k, gap >>= 1)
        if(gap & 1) v = f[v][k];
    if(u == v) return u;
    // Jump both endpoints as long as they stay on different branches.
    for(int k = 19; k >= 0; --k){
        if(f[u][k] == f[v][k]) continue;
        u = f[u][k];
        v = f[v][k];
    }
    return f[u][0];
}
// Returns the rank of the k-th smallest element of the multiset a + b - c - d,
// where a, b, c, d are nodes covering the same rank range [l, r].
// Each step descends into the left child when it holds at least k elements;
// otherwise it descends right after discounting the left child's count.
int solve(int l, int r, int a, int b, int c, int d, int k){
    while(l != r){
        const int half = (l + r) >> 1;
        const int inLeft = t[t[a].lc].val + t[t[b].lc].val
                         - t[t[c].lc].val - t[t[d].lc].val;
        if(k <= inLeft){
            a = t[a].lc; b = t[b].lc; c = t[c].lc; d = t[d].lc;
            r = half;
        }else{
            k -= inLeft;
            a = t[a].rc; b = t[b].rc; c = t[c].rc; d = t[d].rc;
            l = half + 1;
        }
    }
    return l;
}
// Reads n weights and n-1 edges, discretizes the weights to ranks 1..tot,
// builds one persistent segment tree per vertex (its root path), then answers
// m online queries "k-th smallest weight on the path u..v". The first endpoint
// of each query arrives XORed with the previous answer.
int main(){
    n = read(); m = read();
    if(n <= 0) return 0;  // with tot == 0, build(1, 0) would recurse without end
    for(int i = 1; i <= n; ++i)
        p[i].val = read(), p[i].id = i;
    sort(p + 1, p + n + 1);
    // Assign dense ranks. The explicit i == 1 case matters: p[0] is a
    // zero-initialized sentinel, so comparing against it would give weight 0
    // rank 0, which is outside [1, tot].
    for(int i = 1; i <= n; ++i){
        if(i == 1 || p[i].val != p[i - 1].val){
            w[p[i].id] = ++tot;
            val[tot] = p[i].val;
        }
        else w[p[i].id] = tot;
    }
    for(int i = 1; i < n; ++i){
        // Separate statements: the evaluation order of function arguments
        // is unspecified, so Add(read(), read()) does not fix which read comes first.
        const int u = read();
        const int v = read();
        Add(u, v);
    }
    root[0] = build(1, tot); dfs(1, 0);
    for(int j = 1; j <= 19; ++j)
        for(int i = 1; i <= n; ++i)
            f[i][j] = f[f[i][j - 1]][j - 1];
    for(int i = 1; i <= m; ++i){
        a = read() ^ ans; b = read(); c = read(); lca = LCA(a, b);
        // Path distribution = root(a) + root(b) - root(lca) - root(parent(lca)).
        ans = val[solve(1, tot, root[a], root[b], root[lca], root[f[lca][0]], c)];
        printf("%d\n", ans);
    }
    return 0;
}