[题解][Codeforces]Codeforces Round #635 (Div. 1) 简要题解

  • Chinese Round 果然对中国选手十分友好(

  • 原题解

A

题意

  • 给定一棵 \(n\) 个节点的有根树和一个 \(k\),满足 \(1\le k\le n\)

  • 选出 \(k\) 个点为黑点,其他点为白点

  • 求所有黑点到根的路径上白点个数之和的最大值

  • \(1\le n\le 2\times 10^5\)

做法:贪心

  • 显然一个点为黑点则其子树全为黑点

  • 故问题可以视为 \(k\) 次,每次删掉一个叶子 \(u\),贡献为原树\(dep_u-size_u\)

  • 由于父亲的 \(dep-size\) 一定小于子节点,故取 \(dep-size\) 从大到小排序之后前 \(k\) 大的即可

  • \(O(n\log n)\)

  • 利用 nth_element 可以做到 O(n)

代码

#include <bits/stdc++.h>

template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}

typedef long long ll;

const int N = 2e5 + 5, M = N << 1;

int n, k, ecnt, nxt[M], adj[N], go[M], dep[N], fa[N], d[N], sze[N], a[N];
ll ans;

void add_edge(int u, int v)
{
	nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
	nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
}

void dfs(int u, int fu)
{
	fa[u] = fu; dep[u] = dep[fu] + 1; sze[u] = 1;
	for (int e = adj[u], v; e; e = nxt[e])
		if ((v = go[e]) != fu) dfs(v, u), d[u]++, sze[u] += sze[v];
}

int main()
{
	int x, y;
	read(n); read(k);
	for (int i = 1; i < n; i++) read(x), read(y), add_edge(x, y);
	dfs(1, 0);
	for (int i = 1; i <= n; i++) a[i] = dep[i] - sze[i];
	std::sort(a + 1, a + n + 1);
	for (int i = n - k + 1; i <= n; i++) ans += a[i];
	return std::cout << ans << std::endl, 0;
}

B

题意

  • 给定三个长度分别为 \(n_r,n_g,n_b\) 的数组 \(r,g,b\)

  • 从三个数组中各选一个数,设为 \(x,y,z\),求 \((x-y)^2+(y-z)^2+(z-x)^2\) 的最小值

  • \(1\le n_r,n_g,n_b\le 10^5\)\(1\le r_i,g_i,b_i\le 10^9\)

做法:枚举+双指针

  • 假设 \(x\le y\le z\),则最优情况下 \(x\) 要尽可能大,\(y\) 要尽可能小

  • 故把三个数组排序,枚举 \(x,y,z\) 大小关系的 \(6\) 种排列之后,枚举 \(y\) 的值,用指针维护最大的 \(x\) 和最小的 \(z\)

  • \(O(n_r\log n_r+n_g\log n_g+n_b\log n_b)\)

代码

#include <bits/stdc++.h>

template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}

typedef long long ll;

const int N = 1e5 + 5;
const ll INF = 5e18;

int nr, ng, nb, r[N], g[N], b[N];

ll sqr(int x) {return 1ll * x * x;}

ll solve(int na, int nb, int nc, int *a, int *b, int *c)
{
	ll ans = INF;
	for (int i = 1, j = 1, k = 1; j <= nb; j++)
	{
		while (i <= na && a[i] <= b[j]) i++;
		while (k <= nc && b[j] > c[k]) k++;
		if (i > 1 && k <= nc) ans = std::min(ans,
			sqr(a[i - 1] - b[j]) + sqr(b[j] - c[k]) + sqr(c[k] - a[i - 1]));
	}
	return ans;
}

void work()
{
	read(nr); read(ng); read(nb);
	for (int i = 1; i <= nr; i++) read(r[i]);
	for (int i = 1; i <= ng; i++) read(g[i]);
	for (int i = 1; i <= nb; i++) read(b[i]);
	std::sort(r + 1, r + nr + 1); std::sort(g + 1, g + ng + 1);
	std::sort(b + 1, b + nb + 1);
	ll ans = solve(nr, ng, nb, r, g, b);
	ans = std::min(ans, solve(nr, nb, ng, r, b, g));
	ans = std::min(ans, solve(nb, nr, ng, b, r, g));
	ans = std::min(ans, solve(nb, ng, nr, b, g, r));
	ans = std::min(ans, solve(ng, nr, nb, g, r, b));
	ans = std::min(ans, solve(ng, nb, nr, g, b, r));
	printf("%lld\n", ans);
}

int main()
{
	int T; read(T);
	while (T--) work();
	return 0;
}

C

题意

  • 给定长度为 \(n\) 的串 \(S\) 和长度为 \(m\) 的串 \(T\)

  • 一开始有一个空串 \(A\)

  • 每次操作可以选择把 \(S\) 的第一个字符加入 \(A\) 的开头或末尾,并把 \(S\) 的第一个字符删掉

  • 你可以执行任意不超过 \(n\) 的操作次数,求最后能使得 \(T\)\(A\) 的前缀的方案数,对 \(998244353\) 取模

  • \(1\le m\le n\le 3000\)

做法:区间 DP

  • \(f[l,r]\) 表示插入了 \(S\) 的前 \(r-l+1\) 个字符,它们组成了最终的 \(A\) 串的区间 \([l,r]\) 的方案数

  • 组成最终的 \(A\) 串的区间 \([l,r]\),也就是说若 \(i\in[l,r]\)\(i\le m\),则 \(A_i=T_i\)

  • 转移即枚举最后一个字符加在左边还是右边,判断其是否符合限制条件即可

  • 答案为 \(\sum_{i=m}^nf[1,i]\)

  • \(O(n^2)\)

代码

#include <bits/stdc++.h>

const int N = 3005, djq = 998244353;

int n, m, f[N][N], ans;
char s[N], t[N];

int main()
{
	scanf("%s%s", s + 1, t + 1);
	n = strlen(s + 1); m = strlen(t + 1);
	for (int i = 1; i <= n + 1; i++) f[i][i - 1] = 1;
	for (int l = n; l >= 1; l--)
		for (int r = l; r <= n; r++)
		{
			if (l > m || s[r - l + 1] == t[l]) f[l][r] += f[l + 1][r];
			if (r > m || s[r - l + 1] == t[r]) f[l][r] += f[l][r - 1];
			if (f[l][r] >= djq) f[l][r] -= djq;
			if (l == 1 && r >= m)
				ans = (ans + f[l][r]) % djq;
		}
	return std::cout << ans << std::endl, 0;
}

D

题意

  • 交互题

  • 你有一堆麻将,点数从 \(1\)\(n\),每种点数的麻将个数在 \([0,n]\) 之间,但你不知道它们具体是多少

  • 初始时可以知道这堆麻将中,碰(大小为 \(3\) 且点数相同的子集)的个数和吃(大小为 \(3\) 且点数形成公差为 \(1\) 的等差数列)的个数

  • 然后你可以加入最多 \(n\) 次某一种点数的麻将,加入一个麻将之后你可以得到此时碰和吃的个数

  • 还原初始时每种点数的麻将个数

  • \(4\le n\le 100\)

做法:数学

  • 当前\(i\) 种麻将有 \(c_i\) 个,则加入一个第 \(i\) 种麻将时会多出 \(\binom{c_i}2\) 个碰和 \(c_{i-2}c_{i-1}+c_{i-1}c_{i+1}+c_{i+1}c_{i+2}\) 个吃

  • 如果只考虑吃的个数,则如果保证 \(c_i>0\) 则可以通过碰的个数的增量还原出 \(c_i\)

  • 考虑求点数为 \(1\) 的个数,可以得到如果事先加入一个 \(1\),就能保证 \(c_i>0\),再加入一个 \(1\) 即可查出 \(ans_1\)

  • 而加入 \(1\) 的好处是吃的个数增量为 \(c_2c_3\)

  • 于是考虑依次加入 \(3,1,2,1\),这样第二次吃的个数增量为 \(ans_2(ans_3+1)\),第四次吃的个数增量为 \((ans_2+1)(ans_3+1)\)

  • 这两个式子作差即可得到 \(ans_3\)。由于 \(ans_3+1>0\),故可以使用除法得到 \(ans_2\)

  • 而实际上我们也可以得到 \(ans_4\):考虑第三次吃的个数增量:\((ans_3+1)(ans_1+1+ans_4)\),也可以利用除法得到

  • 而对于 \(i>4\),也可以加入一个 \(i-2\),这时吃的个数增量表达式中只有 \(ans_i\) 是未知量,可以解出来。不过这样有一个问题:\(ans_{i-1}\) 可能为 \(0\),这样的方程会有无穷多个解

  • 故考虑倒着加:\(n-1,n-2,\dots,3,1,2,1\)

  • 易得 \(3,1,2,1\) 移到最后不影响 \(ans_{1\dots 4}\) 的求解,只是 \(n>4\) 时这样求解出来的 \(ans_4\) 需要减 \(1\)(在 \(n-1,n-2,\dots 4\) 中加上了 \(1\)

  • 然后 \(i\)\(3\)\(n-2\),利用 \(i\) 被加入时吃的个数增量来解出 \(ans_{i+2}\),由于 \(i+1\) 在之前的过程中加过了 \(1\),故可以保证 \(c_{i+1}\) 不为 \(0\),这个方程一定可以解出来

  • \(O(n)\),操作次数为 \(n\)

代码

#include <bits/stdc++.h>

const int N = 110, M = N * N;

int n, ans[N], f[M], a[N], b[N];

void add(int v) {printf("+ %d\n", v); fflush(stdout);}

int main()
{
	scanf("%d", &n);
	for (int i = 1; i <= n + 1; i++) f[i * (i - 1) >> 1] = i;
	scanf("%*d%*d");
	for (int i = 1; i <= n - 4; i++) add(n - i), scanf("%d%d", &a[i], &b[i]);
	add(3); scanf("%d%d", &a[n - 3], &b[n - 3]);
	add(1); scanf("%d%d", &a[n - 2], &b[n - 2]);
	add(2); scanf("%d%d", &a[n - 1], &b[n - 1]);
	add(1); scanf("%d%d", &a[n], &b[n]);
	ans[1] = f[a[n] - a[n - 1]] - 1;
	ans[3] = (b[n] - b[n - 1]) - (b[n - 2] - b[n - 3]) - 1;
	ans[2] = (b[n] - b[n - 1]) / (ans[3] + 1) - 1;
	ans[4] = (b[n - 1] - b[n - 2]) / (ans[3] + 1) - (ans[1] + 1) - (n > 4);
	for (int i = n - 3; i >= 2; i--)
	{
		int x = n - i;
		ans[x + 2] = (b[i] - b[i - 1] - ans[x - 2] * ans[x - 1] - ans[x - 1]
			* (ans[x + 1] + 1)) / (ans[x + 1] + 1) - (i > 2);
	}
	printf("! ");
	for (int i = 1; i <= n; i++) printf("%d ", ans[i]);
	return puts(""), 0;
}

E1

题意

  • 给定 \(n\)\([0,2^m)\) 内的数

  • 对于所有的 \(0\le i\le m\),求这些数有多少个子集的异或和,二进制下 \(1\) 的个数为 \(i\)

  • \(1\le n\le 2\times10^5\)\(0\le m\le 35\)

做法:线性基+枚举(\(k\) 较小)/DP(\(k\) 较大)

  • 由于 E2 比 E1 难太太太多,就分开讲了

  • 显然先求线性基,设这个基由 \(k\) 个元素组成

  • 原一个子集的异或和可以表示成线性基内一个子集的异或和,再选上线性基外的一部分 \(0\),也就是线性基内一个子集的贡献为 \(2^{n-k}\)

  • \(k\) 较小的时候,可以暴力枚举每个基变量是否选上:\(O(2^k)\)

  • \(k\) 较大的时候,可以高斯消元求出简化阶梯矩阵(若矩阵第 \(i\) 行第 \(i\) 列为 \(1\) 则第 \(i\) 列的其他元素均为 \(0\)),然后 DP \(f_{i,j,S}\) 表示前 \(i\) 个基变量中选出了 \(j\) 个,不在基上的位异或和为 \(S\) 的方案数,统计答案时答案 \(ans_{j+popcount(S)}+=f_{m-k,j,S}\)\(O(2^{m-k}k^2)\)

  • 结合这两种算法可过 E1

代码

#include <bits/stdc++.h>
 
template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}
 
typedef long long ll;
 
const int N = 2e5 + 5, E = 40, C = 17000, djq = 998244353;
 
int n, m, orz = 1, cnt1, p1[N], cnt0, p0[N], f[E][E][C], st[E], ans[E];
ll a[N], b[E];
 
void ins(ll x)
{
	for (int i = m - 1; i >= 0; i--)
	{
		if (!((x >> i) & 1)) continue;
		if (b[i] == -1) return (void) (b[i] = x);
		else x ^= b[i];
	}
	orz = (orz << 1) % djq;
}
 
int cc(ll x)
{
	int res = 0;
	while (x) res += x & 1, x >>= 1;
	return res;
}
 
int main()
{
	read(n); read(m);
	for (int i = 0; i < m; i++) b[i] = -1;
	for (int i = 1; i <= n; i++) read(a[i]), ins(a[i]);
	for (int i = 0; i < m; i++) if (b[i] != -1)
		for (int j = i + 1; j < m; j++)
			if (b[j] != -1 && ((b[j] >> i) & 1)) b[j] ^= b[i];
	for (int i = 0; i < m; i++)
		if (b[i] != -1) p1[++cnt1] = i;
		else p0[++cnt0] = i;
	if (cnt1 <= 20)
	{
		for (int S = 0; S < (1 << cnt1); S++)
		{
			ll T = 0;
			for (int i = 1; i <= cnt1; i++)
				if ((S >> i - 1) & 1) T ^= b[p1[i]];
			ans[cc(T)]++;
		}
	}
	else
	{
		for (int i = 1; i <= cnt1; i++)
			for (int j = 1; j <= cnt0; j++)
				if ((b[p1[i]] >> p0[j]) & 1) st[i] |= 1 << j - 1;
		f[0][0][0] = 1;
		for (int i = 0; i < cnt1; i++)
			for (int j = 0; j <= i; j++)
				for (int S = 0; S < (1 << cnt0); S++)
				{
					f[i + 1][j][S] = (f[i + 1][j][S] + f[i][j][S]) % djq;
					f[i + 1][j + 1][S ^ st[i + 1]] = (f[i + 1][j + 1][S ^ st[i + 1]]
						+ f[i][j][S]) % djq;
				}
		for (int j = 0; j <= cnt1; j++)
			for (int S = 0; S < (1 << cnt0); S++)
			{
				int x = j + cc(S);
				ans[x] = (ans[x] + f[cnt1][j][S]) % djq;
			}
	}
	for (int i = 0; i <= m; i++) printf("%d ", 1ll * ans[i] * orz % djq);
	puts("");
	return 0;
}

E2

题意

  • 同 E1,\(0\le m\le 53\)

做法:FWT+组合数学

  • 妙啊!!!\(\times 4\)

  • 考虑对于 E1 的第二种算法,把复杂度去掉两个 \(k\)

  • \(A_S\) 表示 \(S\) 是否能被线性基表出,\(F^c_S\) 表示 \(S\)\(1\) 的个数是否为 \(c\)

  • 我们不难 (neng) 想到 \(ans_c\) 等于 \(FWT(A)\times FWT(F^c)\) 所有项之和(这里的 \(\times\) 是点乘)除以 \(2^m\) 后的结果(因为要做 IFWT)

  • 接下来考虑 \(FWT(A)\) 的性质

\(FWT(A)\) 仅由 \(0\)\(2^k\) 组成,且第 \(S\) 位为 \(2^k\) 当且仅当 \(S\) 与线性基内所有变量的交集大小都是偶数

  • 证明:

\(S\) 与所有基变量的交集大小都是偶数,由于 \(S\)\(T\bigoplus U\) 的交集大小在奇偶性上等于 \(S\cap T\)\(S\cap U\) 的大小之和,故 \(S\) 与这个基表出的所有 \(2^k\) 个数的交集大小都为偶数,由 FWT 的定义可知 \(FWT(A)\) 的第 \(S\) 位为 \(2^k\)
否则 \(S\) 与这个基表出的所有 \(2^k\) 个数的交集大小中奇偶各占一半,由 FWT 的定义可知 \(FWT(A)\) 的第 \(S\) 位为 \(0\)

另一个性质:

\(FWT(A)\) 中为 \(2^k\) 的位只有 \(2^{m-k}\) 个,且组成另一个基

  • 证明:

\(FWT(A)\) 中第 \(S\) 位为 \(2^k\) 的条件转化一下:对于一个不在基上的位 \(i\),如果让第 \(i\) 位为 \(1\),则对于每个满足第 \(i\) 位为 \(1\) 的基变量 \(j\),要让 \(S\) 的第 \(j\) 位也异或上 \(1\)
这样就有了 \(m-k\) 个基变量,由于每个基变量的最低位互不相同,故它们可以组成一个基
但原线性基必须是简化阶梯矩阵,否则在基上的位 \(i\) 也会对其他在基上的位 \(j\) 造成影响

  • 于是求出这个大小为 \(m-k\) 的基后暴力枚举每个变量选或不选,即可得到 \(FWT(A)\) 中所有为 \(2^k\) 的位

  • 再考虑 \(FWT(F^c\)),容易发现 \(FWT(F^c)\) 的第 \(S\) 位值只和 \(S\)\(1\) 的个数有关

  • 即对于 \(S\),枚举一个 \(1\) 的个数为 \(c\)\(T\) 贡献 \((-1)^{|S\cap T|}\),相当于枚举一个 \(i\) 表示 \(S\)\(T\) 表示 \(S\)\(T\) 的交集大小

  • 于是 \(FWT(F^c)\) 包含 \(d\)\(1\) 的位值均为:

  • \[w_{c,d}=\sum_{i=0}^{\min(c,d)}(-1)^i\binom di\binom{m-d}{c-i} \]

  • \(FWT(A)\) 中含 \(c\)\(1\) 的下标有 \(q_c\)\(2^k\),则:

  • \[ans_c=\frac 1{2^{m-k}}\sum_{d=0}^mq_dw_{c,d} \]

  • 结合 \(k\) 较小的暴力枚举,复杂度为 \(O(2^{\frac m2}+m^3+n)\)

代码

#include <bits/stdc++.h>

template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}

typedef long long ll;

const int N = 60, djq = 998244353, i2 = 499122177;

int n, m, orz = 1, cnt1, p[N], cnt0, cnt[N], ans[N], C[N][N];
ll b[N], a[N];

void ins(ll x)
{
	for (int i = m - 1; i >= 0; i--)
	{
		if (!((x >> i) & 1)) continue;
		if (b[i] == -1) return (void) (b[i] = x);
		else x ^= b[i];
	}
	orz = (orz << 1) % djq;
}

void dfs(int dep, int tar, ll T)
{
	if (dep == tar + 1) return (void) (ans[__builtin_popcountll(T)]++);
	dfs(dep + 1, tar, T); dfs(dep + 1, tar, T ^ a[dep]);
}

int main()
{
	ll x;
	read(n); read(m);
	for (int i = 0; i < m; i++) b[i] = -1;
	for (int i = 1; i <= n; i++) read(x), ins(x);
	for (int i = 0; i < m; i++) if (b[i] != -1)
		for (int j = i + 1; j < m; j++)
			if (b[j] != -1 && ((b[j] >> i) & 1)) b[j] ^= b[i];
	for (int i = 0; i < m; i++) if (b[i] != -1) a[++cnt1] = b[i];
	if (cnt1 <= 26) dfs(1, cnt1, 0);
	else
	{
		for (int i = 0; i < m; i++) if (b[i] == -1)
		{
			a[++cnt0] = 1ll << i;
			for (int j = i + 1; j < m; j++) if (b[j] != -1 && ((b[j] >> i) & 1))
				a[cnt0] |= 1ll << j;
		}
		dfs(1, cnt0, 0);
		for (int i = 0; i <= m; i++) cnt[i] = ans[i], ans[i] = 0, C[i][0] = 1;
		for (int i = 1; i <= m; i++)
			for (int j = 1; j <= i; j++)
				C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % djq;
		int I = 1;
		for (int i = 1; i <= cnt0; i++) I = 1ll * I * i2 % djq;
		for (int i = 0; i <= m; i++)
			for (int j = 0; j <= m; j++)
			{
				int pl = 0;
				for (int k = 0; k <= j && k <= i; k++)
				{
					int delta = 1ll * C[j][k] * C[m - j][i - k] % djq;
					if (k & 1) pl = (pl - delta + djq) % djq;
					else pl = (pl + delta) % djq;
				}
				ans[i] = (1ll * I * pl % djq * cnt[j] + ans[i]) % djq;
			}
	}
	for (int i = 0; i <= m; i++) printf("%d ", 1ll * ans[i] * orz % djq);
	return puts(""), 0;
}

F

题意

  • 给定 \(n\) 个节点的树,\(m\) 条路径和一个 \(k\)

  • 求有多少对路径的交至少包含 \(k\) 条边

  • \(2\le n,m\le 1.5\times10^5\)\(1\le k\le n\)

做法:分类讨论+倍增+BIT+线段树

  • 任选一个根,先考虑相交的两条路径 LCA 不同的情况

  • 此时可以把一条路径拆成两条(\(s_i\)\(lca_i\)\(t_i\)\(lca_i\))来看待

  • 下面设拆完之后的路径为 \((up_i,down_i)\)\(up_i\) 的深度较小

  • 考虑当 \(dep_{up_i}<dep_{up_j}\) 时,第 \(i\) 条和第 \(j\) 条路径交集至少为 \(k\) 当且仅当 \(up_j\) 沿着 \(down_j\) 的方向走 \(k\) 步之后还在路径 \((down_i,up_i)\)

  • 用倍增处理出每个 \(up_i\) 沿着 \(down_i\) 的方向走 \(k\) 步之后到达的点,用 DFS序+差分+BIT 进行单点加和路径查询即可

  • 再考虑 LCA 相同的情况,设这个 LCA 为 \(u\),这时又分两种:

  • (1)设对于所有的 \(i\) 都有 \(s_i\) 的 DFS 序小于 \(t_i\),则 \(s_i\)\(s_j\) 都不为 \(u\) 且在 \(u\) 的同一棵子树内,\(t_i\)\(t_j\) 也一样

  • (2)反之

  • 先考虑(2),设路径 \(i\)\((x_i,u)\) 部分和路径 \(j\)\((x_j,u)\) 部分有交集(\(x_i,x_j\) 为路径 \(i,j\) 的端点之一)

  • 同样地,这相当于 \(u\) 沿着 \(x_i\) 向下走 \(k\) 步和沿着 \(x_j\) 向下走 \(k\) 步到达的点相同,也可以拆成两条之后用和之前类似的方法处理

  • 而对于(1),考虑 \(v=lca(s_i,s_j)\),方案合法当且仅当:

  • (1)\(u\)\(v\) 的严格祖先

  • (2)\(dep_v-dep_u\ge k\)\(v\) 朝着 \(t_i\)\(dep_v-dep_u+1\) 步之后的节点子树内包含 \(t_j\)

  • (3)\(dep_v-dep_u<k\)\(v\) 朝着 \(t_i\)\(k\) 步之后的节点子树内包含 \(t_j\)

  • 这三个条件中(1)满足且(2)(3)满足一者

  • 如果 \(i\) 的取值集合和 \(j\) 的取值集合给定(不交),则可以建立 \(n\) 棵动态开点线段树,维护每个 LCA 的路径的 \(t\)

  • 把所有 \(j\) 插入到第 \(lca_j\) 棵线段树的 \(dfn_{t_j}\) 位置之后,对于每个 \(i\) 查询第 \(lca_i\) 棵线段树上某个节点的子树和即可

  • 回到原问题,可以 dsu-on-tree:对这棵树每个非叶节点找出一个 preferred child(即设 \(cnt_u=\sum_i[s_i=u]\),preferred child 为 \(cnt_u\) 的和最大的子树),然后 dfs 的过程中,先递归轻儿子并把线段树上的东西清掉,然后递归重儿子,这时不要把线段树上的东西清掉,把重子树以外的所有路径的 \(s\) 加入并统计答案

  • 期间可用一个 set 维护当前子树内的所有路径

  • \(O(m\log^2m+n\log n)\)

  • 本题的巧妙之处就在于,使用了从交点处移动 \(k\) 步的方法,来判断两条路径的交长度是否 \(\ge k\)

代码

#include <bits/stdc++.h>

template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}

typedef long long ll;
typedef std::set<int>::iterator it;

const int N = 15e4 + 5, M = N << 1, L = 1e7 + 5, E = 20;

int n, m, k, ecnt, nxt[M], adj[N], go[M], times, dfn[N], dep[N], fa[N][E],
s[N], t[N], l[N], p[N], A[N], sze[N], cnt[N], son[N], rt[N], ToT, top, stk[M];
ll ans;
std::set<int> orz[N];
std::vector<int> a[N], b[N];

struct node
{
	int lc, rc, sum;
} T[L];

void change(int l, int r, int pos, int v, int &p)
{
	if (!p) p = ++ToT; T[p].sum += v;
	if (l == r) return;
	int mid = l + r >> 1;
	if (pos <= mid) change(l, mid, pos, v, T[p].lc);
	else change(mid + 1, r, pos, v, T[p].rc);
}

int ask(int l, int r, int s, int e, int p)
{
	if (!p || e < l || s > r) return 0;
	if (s <= l && r <= e) return T[p].sum;
	int mid = l + r >> 1;
	return ask(l, mid, s, e, T[p].lc) + ask(mid + 1, r, s, e, T[p].rc);
}

void change(int x, int v)
{
	for (; x <= n; x += x & -x)
		A[x] += v;
}

void sub(int u) {change(dfn[u], 1); change(dfn[u] + sze[u], -1);}

int ask(int x)
{
	int res = 0;
	for (; x; x -= x & -x) res += A[x];
	return res;
}

inline bool comp(int a, int b)
{
	return dep[l[a]] > dep[l[b]] || (dep[l[a]] == dep[l[b]] && l[a] < l[b]);
}

void add_edge(int u, int v)
{
	nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
	nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
}

void dfs(int u, int fu)
{
	dep[u] = dep[fa[u][0] = fu] + (sze[u] = 1);
	for (int i = 0; i < 17; i++) fa[u][i + 1] = fa[fa[u][i]][i];
	dfn[u] = ++times;
	for (int e = adj[u], v; e; e = nxt[e])
		if ((v = go[e]) != fu) dfs(v, u), sze[u] += sze[v];
}

int lca(int u, int v)
{
	if (dep[u] < dep[v]) std::swap(u, v);
	for (int i = 17; i >= 0; i--)
		if (dep[fa[u][i]] >= dep[v])
			u = fa[u][i];
	if (u == v) return u;
	for (int i = 17; i >= 0; i--)
		if (fa[u][i] != fa[v][i])
			u = fa[u][i], v = fa[v][i];
	return fa[u][0];
}

int J(int u, int k)
{
	for (int i = 17; i >= 0; i--)
		if ((k >> i) & 1) u = fa[u][i];
	return u;
}

void init(int u, int fu)
{
	int mx = -1;
	for (int e = adj[u], v; e; e = nxt[e])
		if ((v = go[e]) != fu)
		{
			init(v, u); cnt[u] += cnt[v];
			if (cnt[v] > mx) mx = cnt[v], son[u] = v;
		}
}

void wtf(int u, int i)
{
	if (dfn[l[i]] >= dfn[u] || dfn[u] >= dfn[l[i]] + sze[l[i]]) return;
	int len = dep[u] + dep[t[i]] - dep[l[i]] * 2;
	if (len < k || t[i] == l[i]) return;
	int v = dep[u] - dep[l[i]] >= k ? J(t[i], dep[t[i]] - dep[l[i]] - 1)
		: J(t[i], len - k);
	ans += ask(1, n, dfn[v], dfn[v] + sze[v] - 1, rt[l[i]]);
	if (ask(1, n, dfn[v], dfn[v] + sze[v] - 1, rt[l[i]]));
}

void DFS(int u, int fu)
{
	for (int e = adj[u], v; e; e = nxt[e])
		if ((v = go[e]) != fu && v != son[u])
		{
			DFS(v, u);
			for (it x = orz[v].begin(); x != orz[v].end(); x++)
				change(1, n, dfn[t[*x]], -1, rt[l[*x]]);
		}
	if (son[u]) DFS(son[u], u);
	for (it x = orz[u].begin(); x != orz[u].end(); x++)
		wtf(u, *x), change(1, n, dfn[t[*x]], 1, rt[l[*x]]);
	if (son[u])
	{
		for (int e = adj[u], v; e; e = nxt[e])
		{
			if ((v = go[e]) == fu || v == son[u]) continue;
			for (it x = orz[v].begin(); x != orz[v].end(); x++) wtf(u, *x);
			for (it x = orz[v].begin(); x != orz[v].end(); x++)
				change(1, n, dfn[t[*x]], 1, rt[l[*x]]), orz[son[u]].insert(*x);
		}
		for (it x = orz[u].begin(); x != orz[u].end(); x++)
			orz[son[u]].insert(*x);
		std::swap(orz[u], orz[son[u]]);
	}
}

int main()
{
	int x, y;
	read(n); read(m); read(k);
	for (int i = 1; i < n; i++) read(x), read(y), add_edge(x, y);
	dfs(1, 0);
	for (int i = 1; i <= m; i++)
	{
		read(s[i]); read(t[i]);
		if (dfn[s[i]] > dfn[t[i]]) std::swap(s[i], t[i]);
		l[i] = lca(s[i], t[i]); p[i] = i;
		orz[s[i]].insert(i); cnt[s[i]]++; a[l[i]].push_back(i);
	}
	std::sort(p + 1, p + m + 1, comp);
	for (int i = 1; i <= m;)
	{
		int nxt = i;
		while (nxt <= m && l[p[i]] == l[p[nxt]]) nxt++;
		for (int j = i; j < nxt; j++)
		{
			int x = p[j], u = s[x], v = t[x], w = l[x];
			ans += ask(dfn[u]) + ask(dfn[v]) - ask(dfn[w]) * 2;
		}
		for (int j = i; j < nxt; j++)
		{
			int x = p[j], u = s[x], v = t[x], w = l[x];
			if (dep[u] - dep[w] >= k) sub(J(u, dep[u] - dep[w] - k));
			if (dep[v] - dep[w] >= k) sub(J(v, dep[v] - dep[w] - k));
		}
		i = nxt;
	}
	memset(A, 0, sizeof(A));
	for (int u = 1; u <= n; u++)
	{
		for (int i = 0; i < a[u].size(); i++)
		{
			int x = a[u][i];
			if (dep[s[x]] - dep[u] >= k)
			{
				ans += A[y = J(s[x], dep[s[x]] - dep[u] - k)]++; stk[++top] = y;
				if (t[x] != u) b[J(t[x], dep[t[x]] - dep[u] - 1)].push_back(y);
			}
			if (dep[t[x]] - dep[u] >= k)
			{
				ans += A[y = J(t[x], dep[t[x]] - dep[u] - k)]++; stk[++top] = y;
				if (s[x] != u) b[J(s[x], dep[s[x]] - dep[u] - 1)].push_back(y);
			}
		}
		while (top--) A[stk[top + 1]] = 0; top = 0;
		for (int e = adj[u], v; e; e = nxt[e])
		{
			if ((v = go[e]) == fa[u][0]) continue;
			for (int i = 0; i < b[v].size(); i++)
				ans -= A[y = b[v][i]]++, stk[++top] = y;
			while (top--) A[stk[top + 1]] = 0; top = 0;
		}
	}
	init(1, 0); DFS(1, 0);
	return std::cout << ans << std::endl, 0;
}
posted @ 2020-04-20 15:43  epic01  阅读(279)  评论(0编辑  收藏  举报