P9992 [Ynoi Easy Round 2024] TEST_130 题解

线段树合并被卡常了。

这个题的题意有些瑕疵,即 dd 的范围。如果 dd 大于等于 uu 子树最大距离,那我可以任意取 dd,这个集合岂不是无穷大了?

实际上是,输入的时候存在 dd 大于等于这个最大距离,但我们求的 dd' 的取值必须小于等于子树最大距离。

接着考虑做法:

显然我们求的 wN(w,d)w' \in N(w,d),因为 dd 总是大于等于 00,所以无论如何都有 wN(w,d)w' \in N(w',d'),又 N(w,d)N(w,d)N(w', d') \subset N(w,d),所以一定有 wN(w,d)w' \in N(w,d)

考虑 wN(w,d)w' \in N(w,d) 的贡献。分两种:

  1. ww' 为根的子树完全包含于 N(w,d)N(w,d) 内,此时的贡献是子树高度 +1+1

  2. ww' 为根的子树与 N(w,d)N(w,d) 有交,但不完全包含于 N(w,d)N(w,d)。此时贡献为 depw+ddepw+1dep_w+d-dep_{w'}+1

我们如果处理出每个点深度 depudep_u 以及子树内最大深度 mdumd_u,则上面两种情况分别是 mdwdepw+dmd_{w'} \leq dep_w + dmdw>depw+dmd_{w'} > dep_w + d,贡献的第一个是 mdwdepw+1md_{w'} - dep_{w'}+1,第二个贡献就是上面的那个。

显然这两类都可以通过树状数组或线段树算出。问题在于如何限制 wN(w,d)w' \in N(w,d)。考虑 wN(w,d)w' \in N(w,d) 的充要条件是 ww'ww 子树内求 depwdepw+ddep_{w'} \leq dep_w + d。第一个条件容易考虑,第二个条件只需要对上文第二个情况容斥下,减去 depw>depw+ddep_{w'} > dep_w+d 的贡献。

在子树内的情况显然可以通过把 DFS 序弄下来跑扫描线做,这样可以用树状数组维护。然而用线段树合并常数大,于是我写线段树合并卡了很久。但复杂度都是 O(nlogn)O(n \log n)

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <cassert>
#include <queue>
#include <string>
#include <execution>
#include <vector>
#include <set>
using namespace std;

#define int unsigned

constexpr int N = 1e6 + 5;

int n, m;
basic_string<int> G[N], q[N];
int d[N];

unsigned long long ans[N];
int x, v;
int nl, nr;

struct Nd
{
	int dep, md;
	Nd() = default;
	Nd(int d, int m) : dep(d), md(m) {}
}g[N];

int v2;
int mdd;

int idx;
struct Noded
{
	int lson, rson;
	unsigned long long sum;
	int cnt;
}tr[N * 5];
int rubbish[N * 5], cnt;
#define newnode() (cnt ? rubbish[cnt--] : ++idx)
inline void ins(int& u, int l, int r)
{
	if (!u) u = newnode();
	if (l == r)
	{
		++tr[u].cnt;
		tr[u].sum += x;
		return;
	}
	int mid(l + r >> 1);
	(x <= mid ? ins(tr[u].lson, l, mid) : ins(tr[u].rson, mid + 1, r));
	tr[u].sum += x;
	++tr[u].cnt;
	return;
}
inline unsigned long long qsum(int u, int l, int r)
{
	if (!u) return 0LL;
	if (l >= nl)
	{
		return tr[u].sum;
	}
	int mid(l + r >> 1);
	unsigned long long sum = 0;
	(nl <= mid ? sum = qsum(tr[u].lson, l, mid) : 0);
	((mid ^ n) ? sum += qsum(tr[u].rson, mid + 1, r) : 0);
	return sum;
}
inline int qcnt(int u, int l, int r)
{
	if (!u) return 0;
	if (l >= nl)
	{
		return tr[u].cnt;
	}
	int mid(l + r >> 1), sum = 0;
	(nl <= mid ? sum = qcnt(tr[u].lson, l, mid) : 0);
	((mid ^ n) ? sum += qcnt(tr[u].rson, mid + 1, r) : 0);
	return sum;
}
inline int merge(int p, int q, int l, int r)
{
	if (!p) return q;
	if (!q) return p;
	if (l == r)
	{
		tr[p].sum += tr[q].sum;
		tr[p].cnt += tr[q].cnt;
		rubbish[++cnt] = q;
		tr[q].sum = tr[q].lson = tr[q].rson = tr[q].cnt = 0;
		return p;
	}
	int mid(l + r >> 1);
	tr[p].lson = merge(tr[p].lson, tr[q].lson, l, mid);
	tr[p].rson = merge(tr[p].rson, tr[q].rson, mid + 1, r);
	tr[p].sum += tr[q].sum;
	tr[p].cnt += tr[q].cnt;
	rubbish[++cnt] = q;
	tr[q].sum = tr[q].lson = tr[q].rson = tr[q].cnt = 0;
	return p;
}

int idx2;
struct Node2
{
	int lson, rson;
	unsigned long long sum;
	int cnt;
	unsigned long long sum2;
}tr2[N * 5];
int rubbish2[N * 5], cnt2;
#define newnode() (cnt2 ? rubbish2[cnt2--] : ++idx2)
inline void ins2(int& u, int l, int r)
{
	if (!u) u = newnode();
	if (l == r)
	{
		tr2[u].sum2 += v2;
		++tr2[u].cnt;
		tr2[u].sum += v;
		return;
	}
	int mid(l + r >> 1);
	(x <= mid ? ins2(tr2[u].lson, l, mid) : ins2(tr2[u].rson, mid + 1, r));
	tr2[u].sum += v;
	tr2[u].sum2 += v2;
	++tr2[u].cnt;
	return;
}
inline unsigned long long qsum_(int u, int l, int r)
{
	if (!u) return 0llu;
	if (r <= nr)
	{
		return tr2[u].sum;
	}
	int mid(l + r >> 1);
	unsigned long long sum(qsum_(tr2[u].lson, l, mid));
	(nr > mid ? sum += qsum_(tr2[u].rson, mid + 1, r) : 0);
	return sum;
}
inline unsigned long long qsum22(int u, int l, int r)
{
	if (!u) return 0llu;
	if (l >= nl)
	{
		return tr2[u].sum2;
	}
	int mid(l + r >> 1);
	unsigned long long sum = 0;
	(nl <= mid ? sum = qsum22(tr2[u].lson, l, mid) : 0);
	((mid ^ n) ? sum += qsum22(tr2[u].rson, mid + 1, r) : 0);
	return sum;
}
inline int qcnt2(int u, int l, int r)
{
	if (!u) return 0;
	if (l >= nl)
	{
		return tr2[u].cnt;
	}
	int mid(l + r >> 1), sum = 0;
	if (nl <= mid) sum = qcnt2(tr2[u].lson, l, mid);
	if (mid ^ n) sum += qcnt2(tr2[u].rson, mid + 1, r);
	return sum;
}

inline int merge2(int p, int q, int l, int r)
{
	if (!p) return q;
	if (!q) return p;
	if (!(l ^ r))
	{
		tr2[p].sum += tr2[q].sum;
		tr2[p].cnt += tr2[q].cnt;
		tr2[p].sum2 += tr2[q].sum2;
		rubbish2[++cnt2] = q;
		tr2[q].sum = tr2[q].lson = tr2[q].rson = tr2[q].cnt = tr2[q].sum2 = 0;
		return p;
	}
	int mid(l + r >> 1);
	tr2[p].sum += tr2[q].sum;
	tr2[p].sum2 += tr2[q].sum2;
	tr2[p].cnt += tr2[q].cnt;
	tr2[p].lson = merge2(tr2[p].lson, tr2[q].lson, l, mid);
	tr2[p].rson = merge2(tr2[p].rson, tr2[q].rson, mid + 1, r);
	rubbish2[++cnt2] = q;
	tr2[q].sum = tr2[q].lson = tr2[q].rson = tr2[q].cnt = tr2[q].sum2 = 0;
	return p;
}

static void dfs(int u)
{
	for_each(G[u].begin(), G[u].end(), [&](int j)
	{
		g[j].dep = g[j].md = g[u].dep + 1;
		dfs(j);
		g[u].md = (g[j].md > g[u].md ? g[j].md : g[u].md);
	});
}

struct Node
{
	int r1, r2;
}rt[N];

int scnt;

static void solve(int u)
{
	if (scnt == m) return;
	for_each(G[u].begin(), G[u].end(), [&](int j)
	{
		solve(j);
		rt[u].r1 = merge(rt[u].r1, rt[j].r1, 1, n);
		rt[u].r2 = merge2(rt[u].r2, rt[j].r2, 1, n);
	});
	int nr1(rt[u].r1), nr2(rt[u].r2), md(g[u].md), dep(g[u].dep);
	for_each(q[u].begin(), q[u].end(), [&](int id)
	{
		int d(::d[id]);
		nr = dep + d;
		unsigned long long res(qsum_(rt[u].r2, 1, n));
		nl = nr + 1;
		if (nl <= n)
		{
			res += 1llu * qcnt2(nr2, 1, n) * nl - qsum22(nr2, 1, n);
			res -= 1llu * qcnt(nr1, 1, n) * nl - qsum(nr1, 1, n);
		}
		ans[id] = res + (d + 1 < md - dep + 1 ? d + 1 : md - dep + 1);
		scnt++;
	});
	x = dep;
	ins(rt[u].r1, 1, n);
	x = md, v = md - dep + 1, v2 = dep;
	ins2(rt[u].r2, 1, n);
}

namespace FastIO
{
	char* p1, * p2, buf[1 << 22];
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22), stdin), p1 == p2) ? EOF : *p1++)

	inline int read()
	{
		int x = 0;
		char ch = getchar();
		while (!(ch ^ 48))
		{
			ch = getchar();
		}
		while (ch >= '0' & ch <= '9')
		{
			x = (x + (x << 2) << 1) + (ch ^ '0');
			ch = getchar();
		}
		return x;
	}

	void write(unsigned long long x)
	{
		if (x > 9) write(x / 10);
		putchar(x % 10 ^ 48);
	}

	inline void writeln(unsigned long long x) {
		write(x);
		putchar('\n');
	}
}

signed main()
{
	n = FastIO::read(), m = FastIO::read();
	for (int i = 2; i <= n; ++i)
	{
		G[FastIO::read()] += i;
	}
	g[1].dep = g[1].md = 1;
	dfs(1);
	for (int i = 1; i <= m; ++i)
	{
		int u(FastIO::read());
		d[i] = FastIO::read();
		q[u] += i;
	}
	n = g[1].md;
	solve(1);
	for_each(ans + 1, ans + m + 1, [&](unsigned long long res) {FastIO::writeln(res);});
	return 0;
}
posted @ 2023-12-27 20:03  HappyBobb  阅读(5)  评论(0编辑  收藏  举报  来源