洛谷 P4211 [LNOI2014]LCA
题意
给出一个 \(n\) 个节点的有根树(编号为 \(0\) 到 \(n-1\),根节点为 \(0\))。
一个点的深度定义为这个节点到根的距离 \(+1\),记为 \(dep[i]\)
用 \(LCA(i,j)\) 表示 \(i\) 与 \(j\) 的最近公共祖先。
有 \(q\) 次询问,每次询问给出 \(l\ r\ z\),求\[\sum\limits_{i=l}^r dep[\text{LCA}(i,z)] \]
思路
树链剖分
不喜欢从 \(0\) 开始编号,因此这里从 \(1\) 开始编号,输入的点都要 \(+1\)。
先单独考虑一次查询,我们发现 \(dep[i]\) 其实就是 \(i\) 点到根节点的点数(包括自己),所以单独一次查询时所做的操作就是将区间 \([l,r]\) 中的点到根节点的点上的权值加一,然后查询 \(z\) 到根节点的权值和,可能有些难理解,下面画图理解一下。
扩展到多个询问的情况,这个时候每次执行上面的操作就显得很不对。再看一下原来的式子:
稍加思考我们就可以发现,这个式子可以分成两个部分的差,也就是:
因为询问并不是强制在线的,所以我们可以将询问离线,把每一个询问分为 \(1\sim l-1\) 和 \(1\sim r\) 两个小询问, \(1\sim l-1\) 的询问的 \(tag\) 标记为 \(0\),表示我们要减去这个值。 \(1\sim r\) 的询问的 \(tag\) 标记为 \(1\),表示要加上这个值。
将询问按照端点进行排序,从小到大进行处理,如果当前节点小于询问的右端点,就一直加 \(now\) 到根节点的点权,不断增大 \(now\),直到 \(now=\) 询问端点,计算出当前询问 \(z\) 到根节点的值,然后用 \(tag\) 判断是加还是减。
上述操作可以用树剖套线段树实现,如果学过的话应该很容易理解。
时间复杂度 \(O(q\log^2 n)\)。
代码
/*
Name: P4211 [LNOI2014]LCA
Author: Loceaner
Date: 07/09/20 08:58
Description:
Debug: dfn[x]=++tot写成dfn[tp]=++tot
线段树中update忘记return
*/
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int A = 2e5 + 11;
const int B = 1e6 + 11;
const int mod = 201314;
const int inf = 0x3f3f3f3f;
inline int read() {
char c = getchar();
int x = 0, f = 1;
for ( ; !isdigit(c); c = getchar()) if (c == '-') f = -1;
for ( ; isdigit(c); c = getchar()) x = x * 10 + (c ^ 48);
return x * f;
}
struct node { int to, nxt; } e[A];
int n, q, qcnt, cnt, head[A], ans[A];
struct ques { int r, z, num, flag; } a[A];
int tot, fa[A], siz[A], son[A], dfn[A], top[A], dep[A]; //树剖
inline void add(int from, int to) {
e[++cnt].to = to;
e[cnt].nxt = head[from];
head[from] = cnt;
}
bool cmp(ques a, ques b) {
return a.r < b.r;
}
namespace Seg {
#define lson rt << 1
#define rson rt << 1 | 1
struct tree { int l, r, sum, lazy; } t[A << 2];
inline void pushup(int rt) {
t[rt].sum = (t[lson].sum + t[rson].sum) % mod;
}
inline void pushdown(int rt) {
t[lson].sum += (t[lson].r - t[lson].l + 1) * t[rt].lazy, t[lson].sum %= mod;
t[rson].sum += (t[rson].r - t[rson].l + 1) * t[rt].lazy, t[rson].sum %= mod;
t[lson].lazy += t[rt].lazy, t[lson].lazy %= mod;
t[rson].lazy += t[rt].lazy, t[rson].lazy %= mod;
t[rt].lazy = 0;
}
void build(int rt, int l, int r) {
t[rt].l = l, t[rt].r = r;
if (l == r) return;
int mid = (l + r) >> 1;
build(lson, l, mid), build(rson, mid + 1, r);
pushup(rt);
}
void update(int rt, int l, int r, int x) {
if (l <= t[rt].l && t[rt].r <= r) {
t[rt].sum += (t[rt].r - t[rt].l + 1) * x, t[rt].sum %= mod;
t[rt].lazy += x, t[rt].lazy %= mod;
return;
}
if (t[rt].lazy) pushdown(rt);
int mid = (t[rt].l + t[rt].r) >> 1;
if (l <= mid) update(lson, l, r, x);
if (r > mid) update(rson, l, r, x);
pushup(rt);
}
int query(int rt, int l, int r) {
if (l <= t[rt].l && t[rt].r <= r) return t[rt].sum;
if (t[rt].lazy) pushdown(rt);
int mid = (t[rt].l + t[rt].r) >> 1, ans = 0;
if (l <= mid) ans += query(lson, l, r);
if (r > mid) ans += query(rson, l, r);
return ans % mod;
}
}
void prepare(int x, int f) {
fa[x] = f, siz[x] = 1, dep[x] = dep[f] + 1;
for (int i = head[x]; i; i = e[i].nxt) {
int to = e[i].to;
if (to == f) continue;
prepare(to, x), siz[x] += siz[to];
if (siz[to] > siz[son[x]]) son[x] = to;
}
}
void dfs(int x, int tp) {
top[x] = tp, dfn[x] = ++tot;
if (son[x]) dfs(son[x], tp);
for (int i = head[x]; i; i = e[i].nxt) {
int to = e[i].to;
if (to == fa[x] || to == son[x]) continue;
dfs(to, to);
}
}
inline void add_val(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]]) swap(x, y);
Seg::update(1, dfn[top[x]], dfn[x], 1);
x = fa[top[x]];
}
if (dep[x] > dep[y]) swap(x, y);
Seg::update(1, dfn[x], dfn[y], 1);
return;
}
inline int ask_val(int x, int y) {
int ans = 0;
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]]) swap(x, y);
ans += Seg::query(1, dfn[top[x]], dfn[x]);
x = fa[top[x]];
}
if (dep[x] > dep[y]) swap(x, y);
ans += Seg::query(1, dfn[x], dfn[y]);
return ans;
}
int main() {
n = read(), q = read();
for (int i = 2; i <= n; i++) {
fa[i] = read() + 1;
add(fa[i], i), add(i, fa[i]);
}
prepare(1, 0), dfs(1, 1), Seg::build(1, 1, n);
for (int i = 1; i <= q; i++) {
int l = read() + 1, r = read() + 1, z = read() + 1;
a[++qcnt] = (ques) {l - 1, z, i, 0};
a[++qcnt] = (ques) {r, z, i, 1};
}
sort(a + 1, a + 1 + qcnt, cmp);
int now = 1;
for (int i = 1; i <= qcnt; i++) {
while (now <= a[i].r) add_val(1, now++);
if (a[i].flag == 1) ans[a[i].num] += ask_val(1, a[i].z);
else ans[a[i].num] -= ask_val(1, a[i].z);
ans[a[i].num] += mod, ans[a[i].num] %= mod;
}
for (int i = 1; i <= q; i++) cout << ans[i] << '\n';
return 0;
}