CodeForces 1856E2 PermuTree (hard version)
考虑局部贪心,假设我们现在在 \(u\),我们希望 \(u\) 不同子树中的 \((v, w), a_v < a_u < a_w\) 的对数尽量多。
我们实际上只关心子树内 \(a_u\) 的相对大小关系,不关心它们具体是什么。如果 \(u\) 只有两个儿子 \(v, w\),我们可以让 \(v\) 子树内的 \(a\) 全部小于 \(w\) 子树内的 \(a\),这样 \(u\) 作为 \(\text{LCA}\) 的贡献是 \(sz_v \times sz_w\),是最大的。
那么对于 \(u\) 有多个儿子的情况,推广可知相当于把 \(u\) 的儿子分成 \(S, T\) 两个集合,最大化 \(\sum\limits_{v \in S} sz_v \times \sum\limits_{v \in T} sz_v\)。考虑做一个 \(sz_v\) 的 01 背包,若能把 \(sz_v\) 分成大小为 \(x\) 的集合,\(u\) 对答案的贡献是 \(x \times (sz_u - 1 - x)\)。取这个的最大值即可。
01 背包暴力做即可,根据树形背包的那套理论,每个点对只会在 \(\text{LCA}\) 处被统计,所以时间复杂度 \(O(n^2)\),可以通过 E1。
对于 E2,我们肯定不能再暴力 01 背包了。发现我我们背包的复杂度跟 \(sz_v\) 有关。联想到 dsu on tree,轻子树的大小之和为 \(O(n \log n)\)。于是我们考虑将 \(u\) 的 \(sz\) 最大的两个儿子拎出来,剩下的儿子做一个背包,然后再枚举那两个儿子选不选。
至于如何做背包,我们把 \(sz_v\) 相同的物品看做一种有多个的物品,做单调队列优化多重背包即可。因为去掉两个最大子树后,\(sz_v\) 之和为 \(n \log n\),所以不同的 \(sz_v\) 有 \(O(\sqrt{n \log n})\) 种。
所以这么算下来复杂度其实是 \(O(n \sqrt{n \log n} \log n)\),但是它过了???
code
// Problem: E2. PermuTree (hard version)
// Contest: Codeforces - Codeforces Round 890 (Div. 2) supported by Constructor Institute
// URL: https://codeforces.com/contest/1856/problem/E2
// Memory Limit: 512 MB
// Time Limit: 3000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
inline int read() {
char c = getchar();
int x = 0;
for (; !isdigit(c); c = getchar()) ;
for (; isdigit(c); c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
return x;
}
const int maxn = 1000100;
int n, sz[maxn];
bool f[maxn], g[maxn];
ll ans;
vector<int> G[maxn];
void dfs(int u) {
sz[u] = 1;
vector<int> vc;
for (int v : G[u]) {
dfs(v);
sz[u] += sz[v];
vc.pb(sz[v]);
}
int m = (int)vc.size();
if (m <= 2) {
ll mx = 0;
for (int S = 0; S < (1 << m); ++S) {
int s = 0;
for (int i = 0; i < m; ++i) {
if (S & (1 << i)) {
s += vc[i];
}
}
mx = max(mx, 1LL * s * (sz[u] - 1 - s));
}
ans += mx;
return;
}
sort(vc.begin(), vc.end(), greater<int>());
int s = 0;
for (int i = 2; i < m; ++i) {
s += vc[i];
}
s /= 2;
for (int i = 0; i <= s; ++i) {
f[i] = 0;
}
f[0] = 1;
for (int l = 2, r = 2; l < m; l = (++r)) {
while (r + 1 < m && vc[r + 1] == vc[l]) {
++r;
}
for (int i = 0; i <= s; ++i) {
g[i] = f[i];
f[i] = 0;
}
int c = r - l + 1, v = vc[l];
for (int i = 0; i < v; ++i) {
int cnt = 0, t = i;
for (int j = i, k = 0; j <= s; j += v, ++k) {
cnt += g[j];
if (k > c) {
cnt -= g[t];
t += v;
}
f[j] = (cnt ? 1 : 0);
}
}
}
ll mx = 0;
for (int i = 0; i <= s; ++i) {
if (!f[i]) {
continue;
}
for (int S = 0; S < 4; ++S) {
int k = ((S & 1) ? vc[0] : 0) + ((S & 2) ? vc[1] : 0) + i;
mx = max(mx, 1LL * k * (sz[u] - 1 - k));
}
}
ans += mx;
}
void solve() {
n = read();
for (int i = 2, p; i <= n; ++i) {
p = read();
G[p].pb(i);
}
dfs(1);
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}