「解题报告」ARC125F Tree Degree Subset Sum
很神奇的题。
首先容易发现这个树是没什么用的,直接转成度数数组。然后这个度数数组可以是满足 \(\sum d_i = 2n - 2, d_i \ge 1\) 中的任意一个数组。
\(d_i \ge 1\) 这个限制很奇怪,我们考虑将所有的 \(d_i\) 减掉 \(1\),得到新的数组。此时有 \(\sum d_i = n - 2\)。设 \(z\) 为 \(0\) 的数量。
首先我们给出一个结论:对于每一种大小 \(v\),能够组成这个大小 \(v\) 的可能的物品数量是一段区间。
设能够组成大小 \(v\) 所需物品最多为 \(mx(v)\),最少为 \(mn(v)\),那么上述结论告诉我们答案为 \(\sum mx(v) - mn(v) + 1\)。
考虑证明。首先容易发现,对于最少的方案,肯定不会选取任何一个 \(0\);同理,最多的方案会选取所有的 \(0\)。那么我们只需要证明 \(mx(v) - mn(v) \le 2z\),就说明我们可以从最少的方案或最大的方案通过加减 \(0\) 取得所有方案。
这个也是容易证明的,对于任意一种选取方案 \(S\),其大小为 \(s\),和为 \(v\),那么我们可以分析出 \(-z \le v - s \le z - 2\)。\(v - s \ge -z\) 的原因很显然,我们只有 \(z\) 个 \(0\);假如我们有 \((v, s)\) 的方案,那么就一定有 \((n - 2 - v, n - s)\) 的方案,那么 \((n - 2 - v) - (n - s) \ge -z \to v - s \le z - 2\)。那么这就说明 \(-z \le v - mx(v) \le v - mn(v) \le z - 2\),就有 \(mx(v) - mn(v) \le 2z\) 了。
那么我们现在只需要求出 \(mn(v), mx(v)\) 了。\(\sum d_i = n - 2\) 这个限制是很经典的,根号分治可以分析出不同的度数只有 \(O(\sqrt{n})\) 种,那么我们实际上就是要做一个 \(O(\sqrt n)\) 个物品,\(O(n)\) 大小的多重背包,使用单调队列优化多重背包即可做到 \(O(n \sqrt n)\) 的复杂度。
能够有这些性质的原因大概是物品总大小小于总个数,所以性质很优秀,而一开始总大小 \(2n-2\) 并没有什么优秀的性质。
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005;
int n, d[MAXN];
int m, w[MAXN], c[MAXN];
int cnt[MAXN];
int fmn[2][MAXN], gmn[2][MAXN];
int fmx[2][MAXN], gmx[2][MAXN];
int mx[MAXN], mn[MAXN];
int q[MAXN], h, t;
int main() {
scanf("%d", &n);
for (int i = 1; i < n; i++) {
int u, v; scanf("%d%d", &u, &v);
d[u]++, d[v]++;
}
for (int i = 1; i <= n; i++) {
d[i]--;
}
for (int i = 1; i <= n; i++) {
cnt[d[i]]++;
}
for (int i = 1; i <= n - 2; i++) {
if (cnt[i]) {
w[++m] = i, c[m] = cnt[i];
}
}
int o = 0;
memset(fmn, 0x3f, sizeof fmn);
fmn[o][0] = 0;
for (int i = 1; i <= m; i++) {
memset(gmn[o ^ 1], 0x3f, sizeof gmn[o ^ 1]);
for (int x = 0; x < w[i]; x++) {
h = 0, t = 1;
for (int y = 0; y <= (n - 2 - x) / w[i]; y++) {
int j = x + y * w[i];
gmn[o][j] = fmn[o][j] - y;
while (h >= t && q[t] / w[i] < y - c[i]) t++;
while (h >= t && gmn[o][q[h]] > gmn[o][j]) h--;
q[++h] = j;
gmn[o ^ 1][j] = gmn[o][q[t]];
fmn[o ^ 1][j] = gmn[o ^ 1][j] + y;
}
}
o ^= 1;
}
for (int j = 0; j <= n - 2; j++) {
mn[j] = fmn[o][j];
}
o = 0;
memset(fmx, 0xef, sizeof fmx);
fmx[o][0] = cnt[0];
for (int i = 1; i <= m; i++) {
memset(gmx[o ^ 1], 0xef, sizeof gmx[o ^ 1]);
for (int x = 0; x < w[i]; x++) {
h = 0, t = 1;
for (int y = 0; y <= (n - 2 - x) / w[i]; y++) {
int j = x + y * w[i];
gmx[o][j] = fmx[o][j] - y;
while (h >= t && q[t] / w[i] < y - c[i]) t++;
while (h >= t && gmx[o][q[h]] < gmx[o][j]) h--;
q[++h] = j;
gmx[o ^ 1][j] = gmx[o][q[t]];
fmx[o ^ 1][j] = gmx[o ^ 1][j] + y;
}
}
o ^= 1;
}
for (int j = 0; j <= n - 2; j++) {
mx[j] = fmx[o][j];
}
long long ans = 0;
for (int j = 0; j <= n - 2; j++) {
if (mx[j] >= mn[j]) {
ans += mx[j] - mn[j] + 1;
}
}
printf("%lld\n", ans);
return 0;
}