CF1034C Region Separation
CF1034C Region Separation
题目大意
给定一棵 \(n\) 个点的树。每个点有权值,记为 \(a_{1\dots n}\)。
你想砍树。你可以砍任意轮,每轮你选择一些边(至少一条)断开,需要满足每轮结束后每个连通块的权值和是相等的。
求有多少种砍树方案。两种方案不同,当且仅当轮数不同或者某一轮砍的边不同。答案对 \(10^9 + 7\) 取模。
数据范围:\(1\leq n\leq 10^6\),\(1\leq a_i\leq 10^9\)。
本题题解
取任意点为根(不妨取 \(1\)),把树变成有根树。设 \(S\) 为所有节点的点权和,\(s_i\) 表示以 \(i\) 为根的子树内的点权和。
考虑只进行一轮划分,有多少种方案。设把树划分为了 \(k\) 块,那么存在方案的一个必要条件是:\(k\) 是 \(S\) 的约数。此时每块内的点权和是 \(\frac{S}{k}\)。
我们从叶子开始,类似树形 DP 的过程,每当当前连通块的和达到 \(\frac{S}{k}\),就把当前连通块割下来。具体来说,对所有 \(i = 2\dots n\),如果 \(s_{i} \bmod \frac{S}{k} = 0\),就切断 \(i\) 和 \(\mathrm{fa}(i)\) 之间的边。容易发现,如果有解,我们这样会求出唯一的一组解。如果无解,按此方法最终得到的连通块数会小于 \(k\)。
通过如上讨论,我们已经可以回答【只进行一轮划分,把树划分为 \(k\) 块的方案数】,设为 \(f(k)\),则:
也就是说,\(f(k)\) 只能等于 \(0\) 或 \(1\),并且它为 \(1\) 当且仅当:\(k\) 是 \(S\) 的约数,且恰有 \(k\) 个 \(i\) 满足 \(s_i \bmod \frac{S}{k} = 0\)。
暴力求出 \(f(1)\dots f(n)\) 的时间复杂度是 \(\mathcal{O}(n^2)\) 的,我们想要更快。
考虑 \(s_i \bmod \frac{S}{k} = 0\) 这个要求。转化一下:
也就是说,这个要求等价于 \(k\) 是 \(\frac{S}{\gcd(S, s_i)}\) 的倍数。
那么对每个 \(s_i\),设 \(v = \frac{S}{\gcd(S, s_i)}\),它可以对 $k = v, 2v, 3v, \dots $ 产生贡献,另外注意 \(k \leq n\)。于是我们用桶存一下,然后累加一遍,就能在调和级数的时间复杂度内,求出所有 \(f(k)\) 了。时间复杂度 \(\mathcal{O}(n\log n)\)。
然后考虑不止划分一轮的问题。设 \(\mathrm{dp}(k)\) 表示划分了若干轮,得到 \(k\) 个连通块的方案数。根据划分方案的唯一性,它上一轮划分出的连通块数,能且仅能是 \(k\) 的约数。于是可以写出转移:
答案就是 \(\sum_{k = 1}^{n} \mathrm{dp}(k)\)。这个 DP 的时间复杂度也是调和级数: \(\mathcal{O}(n\log n)\) 的。
时间复杂度 \(\mathcal{O}(n(\log n + \log a))\),其中 \(\log a\) 来自求 \(\gcd\)。
参考代码
片段:
const int MAXN = 1e6;
const int MOD = 1e9 + 7;
int n, a[MAXN + 5], fa[MAXN + 5];
ll s[MAXN + 5], S;
int f[MAXN + 5], dp[MAXN + 5];
ll gcd(ll x, ll y) { return (!y) ? x : gcd(y, x % y); }
int main() {
cin >> n;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
s[i] = a[i];
S += a[i];
}
for (int i = 2; i <= n; ++i) {
cin >> fa[i];
}
for (int i = n; i >= 2; --i) {
s[fa[i]] += s[i];
}
assert(s[1] == S);
for (int i = 1; i <= n; ++i) {
ll v = S / gcd(S, s[i]);
if (v <= n) {
f[v]++;
}
}
for (int i = n; i >= 1; --i) {
for (int j = i + i; j <= n; j += i) {
f[j] += f[i];
}
}
for (int i = 1; i <= n; ++i) {
f[i] = (f[i] == i); // 恰好能划分为 i 块
}
dp[1] = 1;
int ans = 0;
for (int i = 1; i <= n; ++i) {
if (!f[i]) {
dp[i] = 0;
continue;
}
for (int j = i + i; j <= n; j += i) {
dp[j] = (dp[j] + dp[i]) % MOD;
}
ans = (ans + dp[i]) % MOD;
}
cout << ans << endl;
return 0;
}