[Luogu] P4211 [LNOI2014]LCA

\(Link\)

Description

给出一个\(n\)个节点的有根树(编号为\(0\)\(n−1\),根节点为\(0\))。

一个点的深度定义为这个节点到根的距离 \(+1\)

\(dep[i]\)表示点\(i\)的深度,\(LCA(i,j)\)表示\(i\)\(j\)的最近公共祖先。

\(q\)次询问,每次询问给出\(l\ r\ z\),求\(\sum_{i=l}^r dep[LCA(i,z)]\)

Solution

注意到\(dep[LCA(x, y)]\)就是先把\(x\)到根路径上的点全部加一,再查询\(y\)到根路径上的点权之和。(因为\(LCA(x,y)\)到根走过的点,就是\(x\)到根的和\(y\)到根的路径的重复部分。)

而显然\(\sum_{i=l}^r dep[LCA(i,z)]=\sum_{i=1}^r dep[LCA(i,z)]-\sum_{i=1}^{l-1} dep[LCA(i,z)]\)。所以我们现在要解决如何快速求\(\sum_{i=1}^x dep[LCA(i,y)]\)

根据之前的结论,这其实就是把从\(1\)\(x\)的所有点到根路径上的点全部加一,再查询\(y\)到根路径上的点权之和。我们很容易会想到树链剖分。

但对每一个\(y\)都做一遍类似的操作显然不现实。注意到\(\sum_{i=1}^x\)之间是有很多重复的。于是我们可以把每个询问拆成\((l-1,-1)\)\((r,1)\),然后把\(pos\)从小到大排序。维护一个\(now\)指针,在把\(now\)从小到大指向当前询问\(pos\)的同时\(add(1, now)\),即把\(now\)到根路径上的点全部加一。\(now\)指向\(pos\)之后,设\(z\)是当前询问的\(z\)\(calc(1,pos)\)算出当前\(pos\)到根路径上的点权之和,就是\(\sum_{i=1}^{pos} dep[LCA(i,z)]\),贡献对应的加在或减在\(res[id[pos]]\)上。

路径加和求和都是树剖基本操作了。

Code

#include <bits/stdc++.h>

using namespace std;

#define ls(x) (x << 1)
#define rs(x) (x << 1 | 1)

const int mod = 201314;

int n, q, tot, cnt, tt, res[50005], id[50005], top[50005], f[50005], dep[50005], sz[50005], mson[50005], hd[50005], to[100005], nxt[100005];

struct node
{
	int l, r, sum, add;
}t[200005];

struct rd
{
	int pos, id, z, fl;
}ask[100005];

int read()
{
	int x = 0, fl = 1; char ch = getchar();
	while (ch < '0' || ch > '9') { if (ch == '-') fl = -1; ch = getchar();}
	while (ch >= '0' && ch <= '9') {x = (x << 1) + (x << 3) + ch - '0'; ch = getchar();}
	return x * fl;
}

int cmp(rd p, rd q)
{
	return p.pos < q.pos;
}

void add(int x, int y)
{
	tot ++ ;
	to[tot] = y;
	nxt[tot] = hd[x];
	hd[x] = tot;
	return;
}

void push_up(int p)
{
	t[p].sum = (t[ls(p)].sum + t[rs(p)].sum) % mod;
	return;
}

void push_down(int p)
{
	if (!t[p].add) return;
	t[ls(p)].add = (t[ls(p)].add + t[p].add) % mod;
	t[rs(p)].add = (t[rs(p)].add + t[p].add) % mod;
	t[ls(p)].sum = (t[ls(p)].sum + t[p].add * (t[ls(p)].r - t[ls(p)].l + 1) % mod) % mod;
	t[rs(p)].sum = (t[rs(p)].sum + t[p].add * (t[rs(p)].r - t[rs(p)].l + 1) % mod) % mod;
	t[p].add = 0;
	return;
}

void update(int p, int l0, int r0, int d)
{
	if (l0 <= t[p].l && t[p].r <= r0)
	{
		t[p].add = (t[p].add + d) % mod;
		t[p].sum = (t[p].sum + (t[p].r - t[p].l + 1) * d % mod) % mod;
		return;
	}
	push_down(p);
	int mid = (t[p].l + t[p].r) >> 1;
	if (l0 <= mid) update(ls(p), l0, r0, d);
	if (r0 > mid) update(rs(p), l0, r0, d);
	push_up(p);
	return;
}


int query(int p, int l0, int r0)
{
	if (l0 <= t[p].l && t[p].r <= r0) return t[p].sum % mod;
	push_down(p);
	int mid = (t[p].l + t[p].r) >> 1, cnt = 0;
	if (l0 <= mid) cnt = cnt + query(ls(p), l0, r0);
	if (r0 > mid) cnt = cnt + query(rs(p), l0, r0);
	return cnt;
}

void build(int p, int l0, int r0)
{
	t[p].l = l0; t[p].r = r0;
	if (l0 == r0) return;
	int mid = (l0 + r0) >> 1;
	build(ls(p), l0, mid);
	build(rs(p), mid + 1, r0);
	push_up(p);
	return;
}

void dfs1(int x, int fa)
{
	sz[x] = 1;
	int mx = -1;
	for (int i = hd[x]; i; i = nxt[i])
	{
		int y = to[i];
		if (y == fa) continue;
		dep[y] = dep[x] + 1;
		f[y] = x;
		dfs1(y, x);
		sz[x] += sz[y];
		if (sz[y] > mx)
		{
			mx = sz[y];
			mson[x] = y;
		}
	}
	return;
}

void dfs2(int x, int tp)
{
	id[x] = ++ cnt;
	top[x] = tp;
	if (!mson[x]) return;
	dfs2(mson[x], tp);
	for (int i = hd[x]; i; i = nxt[i])
	{
		int y = to[i];
		if (y == f[x] || y == mson[x]) continue;
		dfs2(y, y);
	}
	return;
}

void q1(int x, int y)
{
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]]) swap(x, y);
		update(1, id[top[x]], id[x], 1);
		x = f[top[x]];
	}
	if (dep[x] > dep[y]) swap(x, y);
	update(1, id[x], id[y], 1);
	return;
}

int q2(int x, int y)
{
	int sum = 0;
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]]) swap(x, y);
		sum = (sum + query(1, id[top[x]], id[x])) % mod;
		x = f[top[x]];
	}
	if (dep[x] > dep[y]) swap(x, y);
	sum = (sum + query(1, id[x], id[y])) % mod;
	return sum;
}

int main()
{
	n = read(); q = read();
	for (int i = 1; i <= n - 1; i ++ )
	{
		int x = read(); x ++ ;
		add(x, i + 1); add(i + 1, x);
	}
	dfs1(1, 0); dfs2(1, 1);
	build(1, 1, n);
	for (int i = 1; i <= q; i ++ )
	{
		int l = read(), r = read(), z = read();
		l ++ ; r ++ ; z ++ ;
		tt ++ ; ask[tt].pos = l - 1; ask[tt].id = i; ask[tt].z = z; ask[tt].fl = -1;
		tt ++ ; ask[tt].pos = r; ask[tt].id = i; ask[tt].z = z; ask[tt].fl = 1;
	}
	sort(ask + 1, ask + tt + 1, cmp);
	int now = 0;
	for (int i = 1; i <= tt; i ++ )
	{
		while (now < ask[i].pos) now ++ , q1(1, now);
		res[ask[i].id] = (res[ask[i].id] + q2(1, ask[i].z) * ask[i].fl + mod) % mod;
	}
	for (int i = 1; i <= q; i ++ )
		printf("%d\n", res[i]);
	return 0;
}
posted @ 2020-11-10 15:21  andysj  阅读(96)  评论(0编辑  收藏  举报