[题解][Codeforces]Codeforces Round #635 (Div. 1) 简要题解
-
Chinese Round 果然对中国选手十分友好(
A
题意
-
给定一棵 \(n\) 个节点的有根树和一个 \(k\),满足 \(1\le k\le n\)
-
选出 \(k\) 个点为黑点,其他点为白点
-
求所有黑点到根的路径上白点个数之和的最大值
-
\(1\le n\le 2\times 10^5\)
做法:贪心
-
显然一个点为黑点则其子树全为黑点
-
故问题可以视为 \(k\) 次,每次删掉一个叶子 \(u\),贡献为原树的 \(dep_u-size_u\)
-
由于父亲的 \(dep-size\) 一定小于子节点,故取 \(dep-size\) 从大到小排序之后前 \(k\) 大的即可
-
\(O(n\log n)\)
-
利用 nth_element 可以做到 O(n)
代码
#include <bits/stdc++.h>
template <class T>
inline void read(T &res)
{
res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
if (bo) res = ~res + 1;
}
typedef long long ll;
const int N = 2e5 + 5, M = N << 1;
int n, k, ecnt, nxt[M], adj[N], go[M], dep[N], fa[N], d[N], sze[N], a[N];
ll ans;
void add_edge(int u, int v)
{
nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
}
void dfs(int u, int fu)
{
fa[u] = fu; dep[u] = dep[fu] + 1; sze[u] = 1;
for (int e = adj[u], v; e; e = nxt[e])
if ((v = go[e]) != fu) dfs(v, u), d[u]++, sze[u] += sze[v];
}
int main()
{
int x, y;
read(n); read(k);
for (int i = 1; i < n; i++) read(x), read(y), add_edge(x, y);
dfs(1, 0);
for (int i = 1; i <= n; i++) a[i] = dep[i] - sze[i];
std::sort(a + 1, a + n + 1);
for (int i = n - k + 1; i <= n; i++) ans += a[i];
return std::cout << ans << std::endl, 0;
}
B
题意
-
给定三个长度分别为 \(n_r,n_g,n_b\) 的数组 \(r,g,b\)
-
从三个数组中各选一个数,设为 \(x,y,z\),求 \((x-y)^2+(y-z)^2+(z-x)^2\) 的最小值
-
\(1\le n_r,n_g,n_b\le 10^5\),\(1\le r_i,g_i,b_i\le 10^9\)
做法:枚举+双指针
-
假设 \(x\le y\le z\),则最优情况下 \(x\) 要尽可能大,\(y\) 要尽可能小
-
故把三个数组排序,枚举 \(x,y,z\) 大小关系的 \(6\) 种排列之后,枚举 \(y\) 的值,用指针维护最大的 \(x\) 和最小的 \(z\)
-
\(O(n_r\log n_r+n_g\log n_g+n_b\log n_b)\)
代码
#include <bits/stdc++.h>
template <class T>
inline void read(T &res)
{
res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
if (bo) res = ~res + 1;
}
typedef long long ll;
const int N = 1e5 + 5;
const ll INF = 5e18;
int nr, ng, nb, r[N], g[N], b[N];
ll sqr(int x) {return 1ll * x * x;}
ll solve(int na, int nb, int nc, int *a, int *b, int *c)
{
ll ans = INF;
for (int i = 1, j = 1, k = 1; j <= nb; j++)
{
while (i <= na && a[i] <= b[j]) i++;
while (k <= nc && b[j] > c[k]) k++;
if (i > 1 && k <= nc) ans = std::min(ans,
sqr(a[i - 1] - b[j]) + sqr(b[j] - c[k]) + sqr(c[k] - a[i - 1]));
}
return ans;
}
void work()
{
read(nr); read(ng); read(nb);
for (int i = 1; i <= nr; i++) read(r[i]);
for (int i = 1; i <= ng; i++) read(g[i]);
for (int i = 1; i <= nb; i++) read(b[i]);
std::sort(r + 1, r + nr + 1); std::sort(g + 1, g + ng + 1);
std::sort(b + 1, b + nb + 1);
ll ans = solve(nr, ng, nb, r, g, b);
ans = std::min(ans, solve(nr, nb, ng, r, b, g));
ans = std::min(ans, solve(nb, nr, ng, b, r, g));
ans = std::min(ans, solve(nb, ng, nr, b, g, r));
ans = std::min(ans, solve(ng, nr, nb, g, r, b));
ans = std::min(ans, solve(ng, nb, nr, g, b, r));
printf("%lld\n", ans);
}
int main()
{
int T; read(T);
while (T--) work();
return 0;
}
C
题意
-
给定长度为 \(n\) 的串 \(S\) 和长度为 \(m\) 的串 \(T\)
-
一开始有一个空串 \(A\)
-
每次操作可以选择把 \(S\) 的第一个字符加入 \(A\) 的开头或末尾,并把 \(S\) 的第一个字符删掉
-
你可以执行任意不超过 \(n\) 的操作次数,求最后能使得 \(T\) 是 \(A\) 的前缀的方案数,对 \(998244353\) 取模
-
\(1\le m\le n\le 3000\)
做法:区间 DP
-
\(f[l,r]\) 表示插入了 \(S\) 的前 \(r-l+1\) 个字符,它们组成了最终的 \(A\) 串的区间 \([l,r]\) 的方案数
-
组成最终的 \(A\) 串的区间 \([l,r]\),也就是说若 \(i\in[l,r]\) 且 \(i\le m\),则 \(A_i=T_i\)
-
转移即枚举最后一个字符加在左边还是右边,判断其是否符合限制条件即可
-
答案为 \(\sum_{i=m}^nf[1,i]\)
-
\(O(n^2)\)
代码
#include <bits/stdc++.h>
const int N = 3005, djq = 998244353;
int n, m, f[N][N], ans;
char s[N], t[N];
int main()
{
scanf("%s%s", s + 1, t + 1);
n = strlen(s + 1); m = strlen(t + 1);
for (int i = 1; i <= n + 1; i++) f[i][i - 1] = 1;
for (int l = n; l >= 1; l--)
for (int r = l; r <= n; r++)
{
if (l > m || s[r - l + 1] == t[l]) f[l][r] += f[l + 1][r];
if (r > m || s[r - l + 1] == t[r]) f[l][r] += f[l][r - 1];
if (f[l][r] >= djq) f[l][r] -= djq;
if (l == 1 && r >= m)
ans = (ans + f[l][r]) % djq;
}
return std::cout << ans << std::endl, 0;
}
D
题意
-
交互题
-
你有一堆麻将,点数从 \(1\) 到 \(n\),每种点数的麻将个数在 \([0,n]\) 之间,但你不知道它们具体是多少
-
初始时可以知道这堆麻将中,碰(大小为 \(3\) 且点数相同的子集)的个数和吃(大小为 \(3\) 且点数形成公差为 \(1\) 的等差数列)的个数
-
然后你可以加入最多 \(n\) 次某一种点数的麻将,加入一个麻将之后你可以得到此时碰和吃的个数
-
还原初始时每种点数的麻将个数
-
\(4\le n\le 100\)
做法:数学
-
设当前第 \(i\) 种麻将有 \(c_i\) 个,则加入一个第 \(i\) 种麻将时会多出 \(\binom{c_i}2\) 个碰和 \(c_{i-2}c_{i-1}+c_{i-1}c_{i+1}+c_{i+1}c_{i+2}\) 个吃
-
如果只考虑吃的个数,则如果保证 \(c_i>0\) 则可以通过碰的个数的增量还原出 \(c_i\)
-
考虑求点数为 \(1\) 的个数,可以得到如果事先加入一个 \(1\),就能保证 \(c_i>0\),再加入一个 \(1\) 即可查出 \(ans_1\)
-
而加入 \(1\) 的好处是吃的个数增量为 \(c_2c_3\)
-
于是考虑依次加入 \(3,1,2,1\),这样第二次吃的个数增量为 \(ans_2(ans_3+1)\),第四次吃的个数增量为 \((ans_2+1)(ans_3+1)\)
-
这两个式子作差即可得到 \(ans_3\)。由于 \(ans_3+1>0\),故可以使用除法得到 \(ans_2\)
-
而实际上我们也可以得到 \(ans_4\):考虑第三次吃的个数增量:\((ans_3+1)(ans_1+1+ans_4)\),也可以利用除法得到
-
而对于 \(i>4\),也可以加入一个 \(i-2\),这时吃的个数增量表达式中只有 \(ans_i\) 是未知量,可以解出来。不过这样有一个问题:\(ans_{i-1}\) 可能为 \(0\),这样的方程会有无穷多个解
-
故考虑倒着加:\(n-1,n-2,\dots,3,1,2,1\)
-
易得 \(3,1,2,1\) 移到最后不影响 \(ans_{1\dots 4}\) 的求解,只是 \(n>4\) 时这样求解出来的 \(ans_4\) 需要减 \(1\)(在 \(n-1,n-2,\dots 4\) 中加上了 \(1\))
-
然后 \(i\) 从 \(3\) 到 \(n-2\),利用 \(i\) 被加入时吃的个数增量来解出 \(ans_{i+2}\),由于 \(i+1\) 在之前的过程中加过了 \(1\),故可以保证 \(c_{i+1}\) 不为 \(0\),这个方程一定可以解出来
-
\(O(n)\),操作次数为 \(n\)
代码
#include <bits/stdc++.h>
const int N = 110, M = N * N;
int n, ans[N], f[M], a[N], b[N];
void add(int v) {printf("+ %d\n", v); fflush(stdout);}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n + 1; i++) f[i * (i - 1) >> 1] = i;
scanf("%*d%*d");
for (int i = 1; i <= n - 4; i++) add(n - i), scanf("%d%d", &a[i], &b[i]);
add(3); scanf("%d%d", &a[n - 3], &b[n - 3]);
add(1); scanf("%d%d", &a[n - 2], &b[n - 2]);
add(2); scanf("%d%d", &a[n - 1], &b[n - 1]);
add(1); scanf("%d%d", &a[n], &b[n]);
ans[1] = f[a[n] - a[n - 1]] - 1;
ans[3] = (b[n] - b[n - 1]) - (b[n - 2] - b[n - 3]) - 1;
ans[2] = (b[n] - b[n - 1]) / (ans[3] + 1) - 1;
ans[4] = (b[n - 1] - b[n - 2]) / (ans[3] + 1) - (ans[1] + 1) - (n > 4);
for (int i = n - 3; i >= 2; i--)
{
int x = n - i;
ans[x + 2] = (b[i] - b[i - 1] - ans[x - 2] * ans[x - 1] - ans[x - 1]
* (ans[x + 1] + 1)) / (ans[x + 1] + 1) - (i > 2);
}
printf("! ");
for (int i = 1; i <= n; i++) printf("%d ", ans[i]);
return puts(""), 0;
}
E1
题意
-
给定 \(n\) 个 \([0,2^m)\) 内的数
-
对于所有的 \(0\le i\le m\),求这些数有多少个子集的异或和,二进制下 \(1\) 的个数为 \(i\)
-
\(1\le n\le 2\times10^5\),\(0\le m\le 35\)
做法:线性基+枚举(\(k\) 较小)/DP(\(k\) 较大)
-
由于 E2 比 E1 难太太太多,就分开讲了
-
显然先求线性基,设这个基由 \(k\) 个元素组成
-
原一个子集的异或和可以表示成线性基内一个子集的异或和,再选上线性基外的一部分 \(0\),也就是线性基内一个子集的贡献为 \(2^{n-k}\)
-
\(k\) 较小的时候,可以暴力枚举每个基变量是否选上:\(O(2^k)\)
-
\(k\) 较大的时候,可以高斯消元求出简化阶梯矩阵(若矩阵第 \(i\) 行第 \(i\) 列为 \(1\) 则第 \(i\) 列的其他元素均为 \(0\)),然后 DP \(f_{i,j,S}\) 表示前 \(i\) 个基变量中选出了 \(j\) 个,不在基上的位异或和为 \(S\) 的方案数,统计答案时答案 \(ans_{j+popcount(S)}+=f_{m-k,j,S}\):\(O(2^{m-k}k^2)\)
-
结合这两种算法可过 E1
代码
#include <bits/stdc++.h>
template <class T>
inline void read(T &res)
{
res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
if (bo) res = ~res + 1;
}
typedef long long ll;
const int N = 2e5 + 5, E = 40, C = 17000, djq = 998244353;
int n, m, orz = 1, cnt1, p1[N], cnt0, p0[N], f[E][E][C], st[E], ans[E];
ll a[N], b[E];
void ins(ll x)
{
for (int i = m - 1; i >= 0; i--)
{
if (!((x >> i) & 1)) continue;
if (b[i] == -1) return (void) (b[i] = x);
else x ^= b[i];
}
orz = (orz << 1) % djq;
}
int cc(ll x)
{
int res = 0;
while (x) res += x & 1, x >>= 1;
return res;
}
int main()
{
read(n); read(m);
for (int i = 0; i < m; i++) b[i] = -1;
for (int i = 1; i <= n; i++) read(a[i]), ins(a[i]);
for (int i = 0; i < m; i++) if (b[i] != -1)
for (int j = i + 1; j < m; j++)
if (b[j] != -1 && ((b[j] >> i) & 1)) b[j] ^= b[i];
for (int i = 0; i < m; i++)
if (b[i] != -1) p1[++cnt1] = i;
else p0[++cnt0] = i;
if (cnt1 <= 20)
{
for (int S = 0; S < (1 << cnt1); S++)
{
ll T = 0;
for (int i = 1; i <= cnt1; i++)
if ((S >> i - 1) & 1) T ^= b[p1[i]];
ans[cc(T)]++;
}
}
else
{
for (int i = 1; i <= cnt1; i++)
for (int j = 1; j <= cnt0; j++)
if ((b[p1[i]] >> p0[j]) & 1) st[i] |= 1 << j - 1;
f[0][0][0] = 1;
for (int i = 0; i < cnt1; i++)
for (int j = 0; j <= i; j++)
for (int S = 0; S < (1 << cnt0); S++)
{
f[i + 1][j][S] = (f[i + 1][j][S] + f[i][j][S]) % djq;
f[i + 1][j + 1][S ^ st[i + 1]] = (f[i + 1][j + 1][S ^ st[i + 1]]
+ f[i][j][S]) % djq;
}
for (int j = 0; j <= cnt1; j++)
for (int S = 0; S < (1 << cnt0); S++)
{
int x = j + cc(S);
ans[x] = (ans[x] + f[cnt1][j][S]) % djq;
}
}
for (int i = 0; i <= m; i++) printf("%d ", 1ll * ans[i] * orz % djq);
puts("");
return 0;
}
E2
题意
- 同 E1,\(0\le m\le 53\)
做法:FWT+组合数学
-
妙啊!!!\(\times 4\)
-
考虑对于 E1 的第二种算法,把复杂度去掉两个 \(k\)
-
设 \(A_S\) 表示 \(S\) 是否能被线性基表出,\(F^c_S\) 表示 \(S\) 中 \(1\) 的个数是否为 \(c\)
-
我们不难 (neng) 想到 \(ans_c\) 等于 \(FWT(A)\times FWT(F^c)\) 所有项之和(这里的 \(\times\) 是点乘)除以 \(2^m\) 后的结果(因为要做 IFWT)
-
接下来考虑 \(FWT(A)\) 的性质
\(FWT(A)\) 仅由 \(0\) 和 \(2^k\) 组成,且第 \(S\) 位为 \(2^k\) 当且仅当 \(S\) 与线性基内所有变量的交集大小都是偶数
- 证明:
若 \(S\) 与所有基变量的交集大小都是偶数,由于 \(S\) 与 \(T\bigoplus U\) 的交集大小在奇偶性上等于 \(S\cap T\) 与 \(S\cap U\) 的大小之和,故 \(S\) 与这个基表出的所有 \(2^k\) 个数的交集大小都为偶数,由 FWT 的定义可知 \(FWT(A)\) 的第 \(S\) 位为 \(2^k\)
否则 \(S\) 与这个基表出的所有 \(2^k\) 个数的交集大小中奇偶各占一半,由 FWT 的定义可知 \(FWT(A)\) 的第 \(S\) 位为 \(0\)
另一个性质:
\(FWT(A)\) 中为 \(2^k\) 的位只有 \(2^{m-k}\) 个,且组成另一个基
- 证明:
把 \(FWT(A)\) 中第 \(S\) 位为 \(2^k\) 的条件转化一下:对于一个不在基上的位 \(i\),如果让第 \(i\) 位为 \(1\),则对于每个满足第 \(i\) 位为 \(1\) 的基变量 \(j\),要让 \(S\) 的第 \(j\) 位也异或上 \(1\)
这样就有了 \(m-k\) 个基变量,由于每个基变量的最低位互不相同,故它们可以组成一个基
但原线性基必须是简化阶梯矩阵,否则在基上的位 \(i\) 也会对其他在基上的位 \(j\) 造成影响
-
于是求出这个大小为 \(m-k\) 的基后暴力枚举每个变量选或不选,即可得到 \(FWT(A)\) 中所有为 \(2^k\) 的位
-
再考虑 \(FWT(F^c\)),容易发现 \(FWT(F^c)\) 的第 \(S\) 位值只和 \(S\) 中 \(1\) 的个数有关
-
即对于 \(S\),枚举一个 \(1\) 的个数为 \(c\) 的 \(T\) 贡献 \((-1)^{|S\cap T|}\),相当于枚举一个 \(i\) 表示 \(S\) 和 \(T\) 表示 \(S\) 和 \(T\) 的交集大小
-
于是 \(FWT(F^c)\) 包含 \(d\) 个 \(1\) 的位值均为:
-
\[w_{c,d}=\sum_{i=0}^{\min(c,d)}(-1)^i\binom di\binom{m-d}{c-i} \]
-
设 \(FWT(A)\) 中含 \(c\) 个 \(1\) 的下标有 \(q_c\) 个 \(2^k\),则:
-
\[ans_c=\frac 1{2^{m-k}}\sum_{d=0}^mq_dw_{c,d} \]
-
结合 \(k\) 较小的暴力枚举,复杂度为 \(O(2^{\frac m2}+m^3+n)\)
代码
#include <bits/stdc++.h>
template <class T>
inline void read(T &res)
{
res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
if (bo) res = ~res + 1;
}
typedef long long ll;
const int N = 60, djq = 998244353, i2 = 499122177;
int n, m, orz = 1, cnt1, p[N], cnt0, cnt[N], ans[N], C[N][N];
ll b[N], a[N];
void ins(ll x)
{
for (int i = m - 1; i >= 0; i--)
{
if (!((x >> i) & 1)) continue;
if (b[i] == -1) return (void) (b[i] = x);
else x ^= b[i];
}
orz = (orz << 1) % djq;
}
void dfs(int dep, int tar, ll T)
{
if (dep == tar + 1) return (void) (ans[__builtin_popcountll(T)]++);
dfs(dep + 1, tar, T); dfs(dep + 1, tar, T ^ a[dep]);
}
int main()
{
ll x;
read(n); read(m);
for (int i = 0; i < m; i++) b[i] = -1;
for (int i = 1; i <= n; i++) read(x), ins(x);
for (int i = 0; i < m; i++) if (b[i] != -1)
for (int j = i + 1; j < m; j++)
if (b[j] != -1 && ((b[j] >> i) & 1)) b[j] ^= b[i];
for (int i = 0; i < m; i++) if (b[i] != -1) a[++cnt1] = b[i];
if (cnt1 <= 26) dfs(1, cnt1, 0);
else
{
for (int i = 0; i < m; i++) if (b[i] == -1)
{
a[++cnt0] = 1ll << i;
for (int j = i + 1; j < m; j++) if (b[j] != -1 && ((b[j] >> i) & 1))
a[cnt0] |= 1ll << j;
}
dfs(1, cnt0, 0);
for (int i = 0; i <= m; i++) cnt[i] = ans[i], ans[i] = 0, C[i][0] = 1;
for (int i = 1; i <= m; i++)
for (int j = 1; j <= i; j++)
C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % djq;
int I = 1;
for (int i = 1; i <= cnt0; i++) I = 1ll * I * i2 % djq;
for (int i = 0; i <= m; i++)
for (int j = 0; j <= m; j++)
{
int pl = 0;
for (int k = 0; k <= j && k <= i; k++)
{
int delta = 1ll * C[j][k] * C[m - j][i - k] % djq;
if (k & 1) pl = (pl - delta + djq) % djq;
else pl = (pl + delta) % djq;
}
ans[i] = (1ll * I * pl % djq * cnt[j] + ans[i]) % djq;
}
}
for (int i = 0; i <= m; i++) printf("%d ", 1ll * ans[i] * orz % djq);
return puts(""), 0;
}
F
题意
-
给定 \(n\) 个节点的树,\(m\) 条路径和一个 \(k\)
-
求有多少对路径的交至少包含 \(k\) 条边
-
\(2\le n,m\le 1.5\times10^5\),\(1\le k\le n\)
做法:分类讨论+倍增+BIT+线段树
-
任选一个根,先考虑相交的两条路径 LCA 不同的情况
-
此时可以把一条路径拆成两条(\(s_i\) 到 \(lca_i\) 和 \(t_i\) 到 \(lca_i\))来看待
-
下面设拆完之后的路径为 \((up_i,down_i)\),\(up_i\) 的深度较小
-
考虑当 \(dep_{up_i}<dep_{up_j}\) 时,第 \(i\) 条和第 \(j\) 条路径交集至少为 \(k\) 当且仅当 \(up_j\) 沿着 \(down_j\) 的方向走 \(k\) 步之后还在路径 \((down_i,up_i)\) 上
-
用倍增处理出每个 \(up_i\) 沿着 \(down_i\) 的方向走 \(k\) 步之后到达的点,用
DFS序+差分+BIT
进行单点加和路径查询即可 -
再考虑 LCA 相同的情况,设这个 LCA 为 \(u\),这时又分两种:
-
(1)设对于所有的 \(i\) 都有 \(s_i\) 的 DFS 序小于 \(t_i\),则 \(s_i\) 和 \(s_j\) 都不为 \(u\) 且在 \(u\) 的同一棵子树内,\(t_i\) 和 \(t_j\) 也一样
-
(2)反之
-
先考虑(2),设路径 \(i\) 的 \((x_i,u)\) 部分和路径 \(j\) 的 \((x_j,u)\) 部分有交集(\(x_i,x_j\) 为路径 \(i,j\) 的端点之一)
-
同样地,这相当于 \(u\) 沿着 \(x_i\) 向下走 \(k\) 步和沿着 \(x_j\) 向下走 \(k\) 步到达的点相同,也可以拆成两条之后用和之前类似的方法处理
-
而对于(1),考虑 \(v=lca(s_i,s_j)\),方案合法当且仅当:
-
(1)\(u\) 是 \(v\) 的严格祖先
-
(2)\(dep_v-dep_u\ge k\) 且 \(v\) 朝着 \(t_i\) 走 \(dep_v-dep_u+1\) 步之后的节点子树内包含 \(t_j\)
-
(3)\(dep_v-dep_u<k\) 且 \(v\) 朝着 \(t_i\) 走 \(k\) 步之后的节点子树内包含 \(t_j\)
-
这三个条件中(1)满足且(2)(3)满足一者
-
如果 \(i\) 的取值集合和 \(j\) 的取值集合给定(不交),则可以建立 \(n\) 棵动态开点线段树,维护每个 LCA 的路径的 \(t\)
-
把所有 \(j\) 插入到第 \(lca_j\) 棵线段树的 \(dfn_{t_j}\) 位置之后,对于每个 \(i\) 查询第 \(lca_i\) 棵线段树上某个节点的子树和即可
-
回到原问题,可以 dsu-on-tree:对这棵树每个非叶节点找出一个 preferred child(即设 \(cnt_u=\sum_i[s_i=u]\),preferred child 为 \(cnt_u\) 的和最大的子树),然后 dfs 的过程中,先递归轻儿子并把线段树上的东西清掉,然后递归重儿子,这时不要把线段树上的东西清掉,把重子树以外的所有路径的 \(s\) 加入并统计答案
-
期间可用一个
set
维护当前子树内的所有路径 -
\(O(m\log^2m+n\log n)\)
-
本题的巧妙之处就在于,使用了从交点处移动 \(k\) 步的方法,来判断两条路径的交长度是否 \(\ge k\)
代码
#include <bits/stdc++.h>
template <class T>
inline void read(T &res)
{
res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
if (bo) res = ~res + 1;
}
typedef long long ll;
typedef std::set<int>::iterator it;
const int N = 15e4 + 5, M = N << 1, L = 1e7 + 5, E = 20;
int n, m, k, ecnt, nxt[M], adj[N], go[M], times, dfn[N], dep[N], fa[N][E],
s[N], t[N], l[N], p[N], A[N], sze[N], cnt[N], son[N], rt[N], ToT, top, stk[M];
ll ans;
std::set<int> orz[N];
std::vector<int> a[N], b[N];
struct node
{
int lc, rc, sum;
} T[L];
void change(int l, int r, int pos, int v, int &p)
{
if (!p) p = ++ToT; T[p].sum += v;
if (l == r) return;
int mid = l + r >> 1;
if (pos <= mid) change(l, mid, pos, v, T[p].lc);
else change(mid + 1, r, pos, v, T[p].rc);
}
int ask(int l, int r, int s, int e, int p)
{
if (!p || e < l || s > r) return 0;
if (s <= l && r <= e) return T[p].sum;
int mid = l + r >> 1;
return ask(l, mid, s, e, T[p].lc) + ask(mid + 1, r, s, e, T[p].rc);
}
void change(int x, int v)
{
for (; x <= n; x += x & -x)
A[x] += v;
}
void sub(int u) {change(dfn[u], 1); change(dfn[u] + sze[u], -1);}
int ask(int x)
{
int res = 0;
for (; x; x -= x & -x) res += A[x];
return res;
}
inline bool comp(int a, int b)
{
return dep[l[a]] > dep[l[b]] || (dep[l[a]] == dep[l[b]] && l[a] < l[b]);
}
void add_edge(int u, int v)
{
nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
}
void dfs(int u, int fu)
{
dep[u] = dep[fa[u][0] = fu] + (sze[u] = 1);
for (int i = 0; i < 17; i++) fa[u][i + 1] = fa[fa[u][i]][i];
dfn[u] = ++times;
for (int e = adj[u], v; e; e = nxt[e])
if ((v = go[e]) != fu) dfs(v, u), sze[u] += sze[v];
}
int lca(int u, int v)
{
if (dep[u] < dep[v]) std::swap(u, v);
for (int i = 17; i >= 0; i--)
if (dep[fa[u][i]] >= dep[v])
u = fa[u][i];
if (u == v) return u;
for (int i = 17; i >= 0; i--)
if (fa[u][i] != fa[v][i])
u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
int J(int u, int k)
{
for (int i = 17; i >= 0; i--)
if ((k >> i) & 1) u = fa[u][i];
return u;
}
void init(int u, int fu)
{
int mx = -1;
for (int e = adj[u], v; e; e = nxt[e])
if ((v = go[e]) != fu)
{
init(v, u); cnt[u] += cnt[v];
if (cnt[v] > mx) mx = cnt[v], son[u] = v;
}
}
void wtf(int u, int i)
{
if (dfn[l[i]] >= dfn[u] || dfn[u] >= dfn[l[i]] + sze[l[i]]) return;
int len = dep[u] + dep[t[i]] - dep[l[i]] * 2;
if (len < k || t[i] == l[i]) return;
int v = dep[u] - dep[l[i]] >= k ? J(t[i], dep[t[i]] - dep[l[i]] - 1)
: J(t[i], len - k);
ans += ask(1, n, dfn[v], dfn[v] + sze[v] - 1, rt[l[i]]);
if (ask(1, n, dfn[v], dfn[v] + sze[v] - 1, rt[l[i]]));
}
void DFS(int u, int fu)
{
for (int e = adj[u], v; e; e = nxt[e])
if ((v = go[e]) != fu && v != son[u])
{
DFS(v, u);
for (it x = orz[v].begin(); x != orz[v].end(); x++)
change(1, n, dfn[t[*x]], -1, rt[l[*x]]);
}
if (son[u]) DFS(son[u], u);
for (it x = orz[u].begin(); x != orz[u].end(); x++)
wtf(u, *x), change(1, n, dfn[t[*x]], 1, rt[l[*x]]);
if (son[u])
{
for (int e = adj[u], v; e; e = nxt[e])
{
if ((v = go[e]) == fu || v == son[u]) continue;
for (it x = orz[v].begin(); x != orz[v].end(); x++) wtf(u, *x);
for (it x = orz[v].begin(); x != orz[v].end(); x++)
change(1, n, dfn[t[*x]], 1, rt[l[*x]]), orz[son[u]].insert(*x);
}
for (it x = orz[u].begin(); x != orz[u].end(); x++)
orz[son[u]].insert(*x);
std::swap(orz[u], orz[son[u]]);
}
}
int main()
{
int x, y;
read(n); read(m); read(k);
for (int i = 1; i < n; i++) read(x), read(y), add_edge(x, y);
dfs(1, 0);
for (int i = 1; i <= m; i++)
{
read(s[i]); read(t[i]);
if (dfn[s[i]] > dfn[t[i]]) std::swap(s[i], t[i]);
l[i] = lca(s[i], t[i]); p[i] = i;
orz[s[i]].insert(i); cnt[s[i]]++; a[l[i]].push_back(i);
}
std::sort(p + 1, p + m + 1, comp);
for (int i = 1; i <= m;)
{
int nxt = i;
while (nxt <= m && l[p[i]] == l[p[nxt]]) nxt++;
for (int j = i; j < nxt; j++)
{
int x = p[j], u = s[x], v = t[x], w = l[x];
ans += ask(dfn[u]) + ask(dfn[v]) - ask(dfn[w]) * 2;
}
for (int j = i; j < nxt; j++)
{
int x = p[j], u = s[x], v = t[x], w = l[x];
if (dep[u] - dep[w] >= k) sub(J(u, dep[u] - dep[w] - k));
if (dep[v] - dep[w] >= k) sub(J(v, dep[v] - dep[w] - k));
}
i = nxt;
}
memset(A, 0, sizeof(A));
for (int u = 1; u <= n; u++)
{
for (int i = 0; i < a[u].size(); i++)
{
int x = a[u][i];
if (dep[s[x]] - dep[u] >= k)
{
ans += A[y = J(s[x], dep[s[x]] - dep[u] - k)]++; stk[++top] = y;
if (t[x] != u) b[J(t[x], dep[t[x]] - dep[u] - 1)].push_back(y);
}
if (dep[t[x]] - dep[u] >= k)
{
ans += A[y = J(t[x], dep[t[x]] - dep[u] - k)]++; stk[++top] = y;
if (s[x] != u) b[J(s[x], dep[s[x]] - dep[u] - 1)].push_back(y);
}
}
while (top--) A[stk[top + 1]] = 0; top = 0;
for (int e = adj[u], v; e; e = nxt[e])
{
if ((v = go[e]) == fa[u][0]) continue;
for (int i = 0; i < b[v].size(); i++)
ans -= A[y = b[v][i]]++, stk[++top] = y;
while (top--) A[stk[top + 1]] = 0; top = 0;
}
}
init(1, 0); DFS(1, 0);
return std::cout << ans << std::endl, 0;
}