CF871E Restore the Tree
CF871E Restore the Tree
题目大意
yzh 有一棵 \(n\) 个节点的妹子树。大给给 ztr 对 yzh 使用了男♂酮♂魔♂法,导致 yzh 把妹子树忘记了!
好在 yzh 有 \(k\) 个最喜欢的妹子(称为关键节点),他记得每个最喜欢的妹子到其它妹子的距离。
现在 yzh 告诉了你 \(n\), \(k\) 以及 \(k\times n\) 个数,表示每个关键节点到其它节点的树上距离。请你帮 yzh 求出任意一种符合要求的树,或告诉他不存在这样的树。
数据范围:\(1\leq n\leq 30000\),\(1\leq k\leq \min\{n, 200\}\)。
本题题解
首先我们可以确定 \(k\) 个关键节点的编号:如果有解,每个关键节点一定和恰好一个节点距离为 \(0\),这个节点就是它自己。记第 \(i\) 个关键节点的编号为 \(\mathrm{id}_i\)。
记 \(\mathrm{dis}(u, v)\) 表示 \(u, v\) 两点的树上距离。
不妨以第 \(1\) 个关键节点为根,记为 \(\mathrm{root} = \mathrm{id}_1\)。
对于其他每个关键节点 \(\mathrm{id}_i\),可以求出根到它的路径。具体来说,一个节点 \(v\) 在 \(\mathrm{root}\) 到 \(\mathrm{id}_i\) 的路径上,当且仅当 \(\mathrm{dis}(\mathrm{root}, v) + \mathrm{dis}(\mathrm{id}_i, v) = \mathrm{dis}(\mathrm{root}, \mathrm{id}_i)\)。对所有满足这样条件的 \(v\),按 \(\mathrm{dis}(\mathrm{root}, v)\) 从小到大排序,就可以得到从 \(\mathrm{root}\) 到 \(\mathrm{id}_i\) 的路径(注意,因为 \(\mathrm{dis}(\mathrm{root}, v) < n\),所以“排序”可以用桶排实现,时间复杂度是 \(\mathcal{O}(n)\) 的)。
对每个关键节点,求出根到它的路径后,类似于向 \(\text{trie}\) 里插入字符串一样,插入这条路径。至此,我们确定了树的结构的一部分:所有关键节点到根的路径的并(以下称这部分为“主体部分”)。这部分时间复杂度 \(\mathcal{O}(nk)\)。
考虑不在主体部分里的节点。我们已知它们的深度(也就是 \(\mathrm{root}\) 到它的距离)。按深度从小到大考虑这些节点,问题相当于要给每个节点确定一个父亲。
对节点 \(v\),考虑它的祖先里,深度最大的、主体部分里的节点,记为 \(u\)。对每个 \(v\),\(u\) 就是所有【关键节点和 \(v\) 的 LCA】中的深度最大者。求一个关键节点 \(\mathrm{id}_i\) 和 \(v\) 的 LCA(记为 \(l\)),可以考虑:\(\mathrm{dis}(\mathrm{root}, \mathrm{id}_i) + \mathrm{dis}(\mathrm{root}, v) - 2\cdot \mathrm{dis}(\mathrm{root}, l) = \mathrm{dis}(\mathrm{id}_i, v)\),解出 \(\mathrm{dis}(\mathrm{root}, l) = x\),那么 \(\mathrm{root}\) 到 \(\mathrm{id}_i\) 的路径上第 \(x\) 个点就是 \(l\) 了。
确定了 \(u\) 后,同时也可以知道 \(u\) 到 \(v\) 的距离。若 \(u\) 就是 \(v\) 的父亲,则已经完成任务(确定了 \(v\) 的父亲)。否则找出任意一个 \(u\) 子树里、不在主体部分中的、距离 \(u\) 为 \(\mathrm{dis}(u, v) - 1\) 的节点,以它作为 \(v\) 的父亲即可。因为我们是按深度从小到大考虑不在主体部分中的节点的,所以 \(v\) 的父亲一定已经被加入过。
枚举 \(\mathcal{O}(n)\) 个 \(v\),\(\mathcal{O}(k)\) 求出对应的 \(u\),所以这部分时间复杂度也是 \(\mathcal{O}(nk)\)。
至此,我们在 \(\mathcal{O}(nk)\) 的时间内解决了本题。
参考代码
以下代码来自 CF 官方题解,比较清晰易懂。
#include <bits/stdc++.h>
using namespace std;
void impossible() {
puts("-1");
exit(0);
}
int main() {
int n, k;
assert(scanf("%d %d", &n, &k) == 2);
assert(k > 0);
vector <int> dist[k];
vector <int> idx(n, 0);
for (int i = 0; i < k; ++i) {
dist[i] = vector <int> (n, 0);
for (int j = 0; j < n; ++j)
assert(scanf("%d", &dist[i][j]) == 1);
idx[i] = -1;
for (int j = 0; j < n; ++j) {
if (dist[i][j] == 0) {
if (idx[i] != -1)
impossible();
else
idx[i] = j;
}
}
if (idx[i] == -1)
impossible();
for (int j = 0; j < i; ++j)
if (idx[i] == idx[j])
impossible();
}
vector <int> p(n, -1);
vector <int> h = dist[0];
int root = idx[0];
vector <int> ps[k];
for (int i = 0; i < k; ++i)
ps[i] = vector <int> ();
for (int t = 1; t < k; ++t) {
int v = idx[t];
vector <int> &d = dist[t];
vector <int> pars(n, -1);
for (int u = 0; u < n; ++u) {
if (h[u] + d[u] == h[v]) {
if (pars[h[u]] != -1)
impossible();
pars[h[u]] = u;
}
}
for (int i = 0; i < n; ++i) {
if (pars[i] == -1 && i <= h[v])
impossible();
if (pars[i] != -1 && i > h[v])
impossible();
}
for (int i = 0; i < h[v]; ++i) {
int pu = pars[i];
int u = pars[i + 1];
if (p[u] != -1 && p[u] != pu)
impossible();
p[u] = pu;
}
ps[t] = pars;
}
if (p[root] != -1)
impossible();
vector <int> subs[n];
for (int i = 0; i < n; ++i)
subs[i] = vector <int> ();
vector <int> vh[n];
for (int i = 0; i < n; ++i)
vh[i] = vector <int> ();
for (int v = 0; v < n; ++v)
vh[h[v]].push_back(v);
for (int hh = 0; hh < n; ++hh) {
for (int v: vh[hh]) {
if (v == root || p[v] != -1)
continue;
int lp = root;
int lt = 0;
for (int t = 1; t < k; ++t) {
int u = idx[t];
vector <int> &d = dist[t];
vector <int> &pars = ps[t];
// h[v] + h[u] - 2 * h[l] == d[v]
// 2 * h[l] == h[v] + h[u] - d[v]
int hl2 = h[v] + h[u] - d[v];
if (hl2 % 2 != 0)
impossible();
int hl = hl2 / 2;
if (hl < 0 || hl > h[u] || hl > h[v])
impossible();
int l = pars[hl];
if (h[l] >= h[lp]) {
if (pars[h[lp]] != lp)
impossible();
lp = l;
lt = t;
} else {
if (ps[lt][h[l]] != l)
impossible();
}
}
subs[lp].push_back(v);
}
}
for (int i = 0; i < n; ++i) {
int prev = i;
int nxt = -1;
for (int v: subs[i]) {
assert(prev != -1);
if (h[v] == h[prev] + 1) {
p[v] = prev;
nxt = v;
} else {
prev = nxt;
if (prev != -1 && h[v] == h[prev] + 1) {
p[v] = prev;
nxt = v;
} else
impossible();
}
}
}
for (int i = 0; i < n; ++i) {
if (i == root)
assert(p[i] == -1);
else {
printf("%d %d\n", i + 1, p[i] + 1);
}
}
return 0;
}