Luogu P2633 Count on a tree

[题目链接 Click Here](https://www.luogu.com.cn/problem/P2633)

在树上建主席树，方法和序列上的差不多。就是被值域的最大值卡了好几次 RE……（下面的代码没有离散化，而是用 `#define int long long` 把值域直接开到 [0, 1e10]。）

#include <bits/stdc++.h>
using namespace std;

// Capacity for per-vertex arrays. Declared before the macro below, so N itself
// stays a plain `int`.
const int N = 100010;
// From here on every `int` is really `long long`: the segment tree's value range
// [0, 1e10] (see dfs/query) does not fit in 32 bits, and (l + r) must not overflow.
#define int long long

// n: number of vertices, m: number of queries, cnt: number of arcs stored in e[],
// val[u]: weight of vertex u, head[u]: index of u's first arc in e[] (0 = no arcs).
int n, m, cnt, val[N], head[N];

// One directed arc of the adjacency lists: `to` is the vertex it points at and
// `nxt` is the index of the next arc leaving the same vertex (0 ends the list).
// Arcs are stored from index 1 upward; two arcs per undirected edge.
struct edge {
	int nxt, to;
	edge (int _nxt = 0, int _to = 0) : nxt (_nxt), to (_to) {}
}e[N << 1];

// Records the undirected edge {u, v} as two arcs, first u -> v and then v -> u,
// each one prepended to the adjacency list of the vertex it leaves.
void add_len (int u, int v) {
	const int tail[2] = {u, v};
	const int tip[2] = {v, u};
	for (int d = 0; d < 2; ++d) {
		e[++cnt] = edge (head[tail[d]], tip[d]);
		head[tail[d]] = cnt;
	}
}

// tot: number of pool nodes handed out so far (node 0 is the shared empty node);
// rt[u]: root of the persistent-tree version built for vertex u (rt[0] = 0 is empty).
int tot, rt[N];

// One node of the persistent (chairman) segment tree over the value range [0, 1e10]:
// left child, right child, and how many inserted values fall inside the node's range.
// Pool size: one insertion walks a root-to-leaf path of at most
// ceil(log2(1e10 + 1)) + 1 = 35 nodes, and there is one insertion per vertex, so
// fewer than 35 * N nodes (plus node 0) are ever used. N * 40 covers that with margin;
// the former N << 7 (128 * N nodes, roughly 300 MB with long long fields) reserved
// about 3.6 times the maximum that can ever be touched.
struct Segment_Node {
	int ls, rs, sz;
}t[N * 40];

// Midpoint of the current range [l, r]; expects locals named l and r at the use site.
#define mid ((l + r) >> 1)

// Inserts value w into the version rooted at _rt (covering range [l, r]) and
// returns the root of the new version. Only the nodes on the path down to w's
// leaf are freshly allocated; every other subtree is shared with _rt.
int modify (int _rt, int l, int r, int w) {
	const int cur = ++tot;
	t[cur] = t[_rt];	// start from the old node: same children, same count
	++t[cur].sz;		// ... plus the value being inserted
	if (l == r) {
		t[cur].ls = t[cur].rs = 0;	// a leaf has no children
		return cur;
	}
	if (w > mid) {
		t[cur].rs = modify (t[_rt].rs, mid + 1, r, w);
	} else {
		t[cur].ls = modify (t[_rt].ls, l, mid, w);
	}
	return cur;
}

// deep[u]: depth of u (the start vertex gets deep[_fa] + 1; sentinel 0 has depth 0).
// fa[u][i]: 2^i-th ancestor of u; entries with 2^i > deep[u] are never written and
// stay 0.
int deep[N], fa[N][21];

// Visits every vertex reachable from u (whose parent is _fa) and fills, for each x:
//   fa[x][*]  binary-lifting ancestor table,
//   deep[x]   depth,
//   rt[x]     persistent-tree version = the parent's version plus val[x].
// Uses an explicit stack instead of recursion: on a path-shaped tree the recursive
// form nests one call per vertex (~1e5 here) and can overflow the call stack.
void dfs (int u, int _fa) {
	std::vector<std::pair<int, int> > stk;	// (vertex, parent) pairs still to visit
	stk.push_back (std::make_pair (u, _fa));
	while (!stk.empty ()) {
		int x = stk.back ().first, p = stk.back ().second;
		stk.pop_back ();
		// A vertex is pushed only while its parent is being processed, so the
		// parent's deep/rt/fa entries are already final at this point.
		fa[x][0] = p;
		deep[x] = deep[p] + 1;
		rt[x] = modify (rt[p], 0, 1e10, val[x]);
		for (int i = 1; (1 << i) <= deep[x]; ++i) {
			fa[x][i] = fa[fa[x][i - 1]][i - 1];
		}
		for (int i = head[x]; i; i = e[i].nxt) {
			if (e[i].to != p) {
				stk.push_back (std::make_pair (e[i].to, x));
			}
		}
	}
}

int lca (int u, int v) {
	if (deep[u] < deep[v]) swap (u, v);
	for (int i = 20; i >= 0; --i) {
		if (deep[u] - (1 << i) >= deep[v]) {
			u = fa[u][i];
		}
	}
	if (u == v) return u;
	for (int i = 20; i >= 0; --i) {
		if (fa[u][i] != fa[v][i]) {
			u = fa[u][i];
			v = fa[v][i];
		}
	}
	return fa[u][0];
}

int query (int u, int v, int k) {
	int _lca = lca (u, v);
	int l1 = rt[_lca], r1 = rt[v];
	int l2 = rt[fa[_lca][0]], r2 = rt[u];
	int l = 0, r = 1e10;
	while (l < r) {
		int lch = 0;
		lch += t[t[r1].ls].sz - t[t[l1].ls].sz;
	    lch += t[t[r2].ls].sz - t[t[l2].ls].sz;
		if (k <= lch) {
			l1 = t[l1].ls, r1 = t[r1].ls;
			l2 = t[l2].ls, r2 = t[r2].ls;
			r = mid;
		} else {
			l1 = t[l1].rs, r1 = t[r1].rs;
			l2 = t[l2].rs, r2 = t[r2].rs;
			l = mid + 1;
			k -= lch;
		}
	}
	return r;
}

signed main () {
	// freopen ("data.in", "r", stdin);
	t[0].sz = t[0].ls = t[0].rs = 0;
	cin >> n >> m;
	for (int i = 1; i <= n; ++i) cin >> val[i];
	for (int i = 1; i <= n - 1; ++i) {
		static int u, v;
		cin >> u >> v;
		add_len (u, v);
	}
	dfs (1, 0);
	for (int i = 1; i <= m; ++i) {
		static int u, v, k, last_ans;
		cin >> u >> v >> k;
		u ^= last_ans;
		last_ans = query (u, v, k);
		cout << last_ans << endl;
	}
} 

posted @ 2019-03-15 19:02  maomao9173  阅读(100)  评论(0)  编辑  收藏  举报