SGU507. Treediff 题解：树上启发式合并 \(O(n \log^2 n)\) 解法（TLE）
题目链接:https://codeforces.com/problemsets/acmsguru/problem/99999/507
题目大意:
给定一棵以 1 为根的树，编号为 n-m+1 到 n 的节点是叶子，每个叶子节点有一个权值。对于每个非叶子节点，求其子树中任意两个叶子节点权值之差的绝对值的最小值（若子树中叶子不足两个，则输出 \(2^{31}-1\)）。
解题思路:
树上启发式合并。
维护"集合中任意两数之差的绝对值的最小值"时，我使用两个 multiset：一个存放集合中的所有数，另一个存放排序后相邻两数之差，后者的最小元素即为答案（具体做法参见我的另一篇随笔）。
multiset 单次操作的复杂度为 \(O(\log n)\)，dsu on tree 总共进行 \(O(n \log n)\) 次插入/删除，因此总的时间复杂度为 \(O(n \log^2 n)\)。
示例代码(虽然目前还是 TLE 的):
#include <bits/stdc++.h>
using namespace std;
// Capacity of the per-node arrays declared below; node ids must stay below this.
constexpr int maxn = 50050;
// Every leaf value currently held by the structure (sorted, duplicates kept).
multiset<int> num_set;
// Gaps between each pair of neighbouring values in num_set; its minimum is the answer.
multiset<int> diff_set;
// Inserts value x into num_set while keeping diff_set equal to the multiset of
// gaps between neighbouring values of num_set.
// Neighbour presence is tracked with explicit flags instead of a -1 sentinel, so a
// stored value of -1 can no longer be mistaken for "no neighbour".
// NOTE(review): gaps are computed in int; assumes value differences fit in int.
void my_add(int x) {
    // First stored value >= x (the successor); its predecessor is the largest value < x.
    multiset<int>::iterator succ = num_set.lower_bound(x);
    const bool has_succ = succ != num_set.end();
    const bool has_pred = succ != num_set.begin();
    int succ_val = 0;
    int pred_val = 0;
    if (has_succ) {
        succ_val = *succ;
        diff_set.insert(succ_val - x);
    }
    if (has_pred) {
        pred_val = *std::prev(succ);
        diff_set.insert(x - pred_val);
    }
    // x lands between two former neighbours, so their direct gap no longer exists.
    if (has_succ && has_pred)
        diff_set.erase(diff_set.find(succ_val - pred_val));
    num_set.insert(x);
}
// Removes one copy of value x (which must be present) from num_set and updates
// diff_set. The gaps x formed with its neighbours are dropped. If x had neighbours
// on both sides, the gap between those two neighbours is restored.
// Neighbour presence uses flags rather than a -1 sentinel, so a value of -1 is handled.
void my_del(int x) {
    multiset<int>::iterator pos = num_set.lower_bound(x);
    assert(pos != num_set.end() && *pos == x);
    multiset<int>::iterator succ = std::next(pos);
    const bool has_succ = succ != num_set.end();
    const bool has_pred = pos != num_set.begin();
    int succ_val = 0;
    int pred_val = 0;
    if (has_succ) {
        succ_val = *succ;
        // Erase exactly one copy of the gap; it must exist if diff_set is consistent.
        multiset<int>::iterator gap = diff_set.find(succ_val - x);
        assert(gap != diff_set.end());
        diff_set.erase(gap);
    }
    if (has_pred) {
        pred_val = *std::prev(pos);
        multiset<int>::iterator gap = diff_set.find(x - pred_val);
        assert(gap != diff_set.end());
        diff_set.erase(gap);
    }
    // With x gone, its two neighbours become adjacent again.
    if (has_succ && has_pred)
        diff_set.insert(succ_val - pred_val);
    num_set.erase(pos);
}
// Returns the smallest gap between two stored values, or INT_MAX when fewer than
// two values are stored (no pair exists).
int func_find() {
    const bool has_pair = num_set.size() >= 2;
    return has_pair ? *diff_set.begin() : INT_MAX;
}
int n;                // number of nodes read from input
int m;                // nodes n-m+1..n are treated as leaves
int sz[maxn];         // subtree sizes in nodes, filled by getsz
int val[maxn];        // value read for each leaf
int res[maxn];        // per-node answer written by dfs
bool big[maxn];       // heavy child currently skipped by add()/del()
vector<int> g[maxn];  // g[p] lists the children of p
// Adds to sz[u] the number of nodes in u's subtree (u included), recursively
// filling sz[] for every descendant. Expects sz[] to start at zero.
void getsz(int u) {
    int total = 1;
    for (int child : g[u]) {
        getsz(child);
        total += sz[child];
    }
    sz[u] += total;
}
// Inserts (via my_add) every leaf value of u's subtree, skipping subtrees rooted at
// children flagged in big[]. dfs sets that flag on the heavy child whose leaves are
// already in the structure.
void add(int u) {
    const bool is_leaf = u >= n - m + 1;
    if (!is_leaf) {
        for (int child : g[u]) {
            if (big[child])
                continue;
            add(child);
        }
        return;
    }
    assert(g[u].empty());
    my_add(val[u]);
}
// Removes (via my_del) every leaf value of u's subtree, skipping subtrees rooted at
// children flagged in big[].
void del(int u) {
    if (u < n - m + 1) {
        for (int child : g[u])
            if (!big[child])
                del(child);
        return;
    }
    assert(g[u].empty());
    my_del(val[u]);
}
void dfs(int u, bool keep) {
int mx = -1, bigSon = -1;
for (auto v: g[u])
if (sz[v] > mx)
mx = sz[ bigSon = v ];
for (auto v: g[u])
if (v != bigSon)
dfs(v, false);
if (bigSon != -1)
dfs(bigSon, true),
big[bigSon] = true;
add(u);
res[u] = func_find();
if (bigSon != -1)
big[bigSon] = false;
if (!keep)
del(u);
}
// Reads the tree (parent of each node 2..n), then the values of leaves n-m+1..n.
// Prints the answer for every internal node 1..n-m, space-separated.
int main() {
    scanf("%d%d", &n, &m);
    for (int child = 2; child <= n; ++child) {
        int parent;
        scanf("%d", &parent);
        g[parent].push_back(child);
    }
    const int first_leaf = n - m + 1;
    for (int leaf = first_leaf; leaf <= n; ++leaf)
        scanf("%d", &val[leaf]);
    getsz(1);
    dfs(1, false);
    for (int node = 1; node < first_leaf; ++node)
        printf("%d ", res[node]);
    return 0;
}