NOI2022D1T3 树上邻域数点
有趣的题,口胡起来很快乐,但码量不小外加不能操作 grader 导致调试很困难,写+调花了五个小时。样例好像也很水,对着数据调才过。。。
特殊性质 \(\text{A}\) 是一条链,这可以直接猫树分治做到 \(O(n \log n)\) 预处理,\(1\) 次操作查询。
特殊性质 \(\text{B}\) 是每次询问的点都是 \(1\)。跟深度有关,可以考虑长链剖分。把轻子树暴力归并到重子树之后,重子树超出轻子树的部分可以直接打一个懒标记,就可以做到 \(O(n)\) 预处理,\(0\) 次操作查询。
知道复杂度认真思考之后尝试用类似树链剖分的东西乱搞出一个做法。首先还是自然想到分子树外的点和子树内的点。但是一般这种树链剖分对于一个点,问题的规模是不能超过子树大小级别的,否则复杂度很可能会炸。但是询问的 \(d\) 可能会超过 \(suze_u\),而对于不同的 \(d\),\(u\) 的答案也不相同,所以几乎不能对 \(u\) 维护很大的 \(d\) 的答案。
但是可以发现如果 \(d\) 大于 \(size_u\),那么这个 \(d\) 很浪费:\(u\) 子树内的点距离远远小于 \(d\)。所以尝试将询问做一些修改:如果 \(d\) 较大,那么可以将 \(u\) 换成 \(fa_u\),并且将 \(d\) 减 \(1\),这样 \(u\) 子树外覆盖到的点不变,而子树内的点仍然被完全覆盖。
不断这样操作,不可行时说明向上一步会导致 \(u\) 子树内的点覆盖不完全。记 \(h_u\) 表示 \(u\) 子树内的点到 \(u\) 的最大距离,那么就有 \(d \le h_u + 1 \le size_u\),这样之后就能将 \(d\) 控制在 \(size_u\) 之内了。
然后考虑在每一条重链上进行链分治。分治的时候记录分治区间和区间内最左边、最右边的点连向区间外面的点中每个深度的答案。分治的时候直接把左半区间和最左边连出去的每个深度的答案跑一遍上面的长链剖分贡献到右半区间,右半区间同理贡献到左半区间。由于已经将询问的 \(d\) 控制在 \(size_u\) 以内,所以只要保留区间内轻子树 \(size\) 和的长度就行。用全局平衡二叉树的结构即可 \(O(n \log n)\) 得到每个重链链顶所需要的值。
然后计算答案还需要算子树内的答案。刚刚的全局平衡二叉树只能保留轻子树的 \(size\) 和,也就是说如果要直接记录每个点所在重链中自身下方的点对询问的贡献,那么复杂度仍然爆炸。仍然尝试把这一部分的复杂度均摊到轻子树上面:可以找到从这个点向下的第一个点,使得这个点的轻子树没有被这次询问完全覆盖。这由于上面将 \(d\) 在 \(h_u + 1\) 之内,是几乎可以做到的。唯一不行的就是 \(d = h_u\) 或 \(d = h_u + 1\),而这时 \(d\) 只超出了 \(1\),所以可以在分治过程中给每个分治节点长度加上 \(1\)。
找到了这个点之后想让复杂度均摊到这个点的轻子树上面。于是可以模仿特殊性质 \(\text{A}\) 的做法:设找到的第一个轻子树不被完全覆盖的点 \(v\),再找到 \(u, v\) 被分治开的区间 \([l, mid, r]\),尝试在这个区间解决询问。
\(v\) 轻子树没有被完全覆盖,所以可以直接把刚刚的分治过程搬过来,只不过深度还要求完全覆盖右边的轻子树,这样复杂度仍然正确,并且解决了右半区间的问题。又可以发现 \(u\) 到 \(v\) 之前的轻子树被完全包含,所以对于左半区间以 \(u\) 为根的答案,可以直接改成以 \(mid\) 为根,\(d\) 加上 \(mid - u\) 的答案。这样 \(u\) 左边被覆盖的点集相同,而 \([u, mid]\) 中轻子树仍然被完全包含。于是左半边也是直接将上面的分治过程搬过来,稍微算一下跑那个长剖的深度至少要是多少即可。询问直接将左右两边拼起来。
我的代码为了方便,没有用全局平衡二叉树,而是用了长链剖分下对每条链进行分治,分治过程直接取中点即可。由于是深度相关,每条长链只会作为轻子树在上面那条长链的分治过程中算 \(O(\log n)\) 次,所以总复杂度仍然是 \(O(n \log n)\)。
Code
#include "count.h"
#include <vector>
#include <algorithm>
using namespace std;
const int N = 200005;
vector<int> to[N];
int h[N], lh[N], son[N];
int fa[N][18], top[N], mp[N], tot;
info e[N];
vector<info> dp[N], lson[N];
info mr(info x, info y) {
if (isempty(x))
return y;
if (isempty(y))
return x;
return MR(x, y);
}
info mc(info x, info y) {
if (isempty(x))
return y;
if (isempty(y))
return x;
return MC(x, y);
}
info getv(vector<info> &a, int p) { return a.empty() ? emptyinfo : a[min((int)a.size() - 1, p)]; }
vector<info> merge(vector<int> a, vector<vector<info>> b, int m = -1) {
static info v[N];
static int l[N], r[N];
int n = a.size();
if (m == -1)
m = 1 << 30;
if (n == 1 && b[0].size() <= 1)
return {emptyinfo};
vector<info> f{emptyinfo};
v[0] = emptyinfo;
for (int i = 1; i != n; i++)
l[i] = i - 1, r[i] = i + 1, v[i] = e[max(a[i - 1], a[i])];
r[0] = 1, l[0] = n;
r[n] = 0, l[n] = n - 1;
for (int i = 0; i != n; i++)
if (b[i].empty())
b[i].push_back(emptyinfo);
int bg = 0;
auto del = [&](int x) {
l[r[x]] = l[x], r[l[x]] = r[x];
if (x == bg)
bg = r[x];
v[x] = mr(v[x], b[x].back());
if (x)
v[r[x]] = mc(v[x], v[r[x]]);
else
v[r[x]] = mr(v[r[x]], v[x]);
};
if(b[0].size() <= 1)
del(0);
for (int len = 1; len <= m; len++) {
info x = emptyinfo;
for (int i = bg; i <= len && i < n; i = r[i]) {
if (len - i == (int)b[i].size() - 1 && r[i] <= len && i != n - 1) {
del(i);
continue;
}
if (l[i])
x = mc(x, v[i]);
else
x = mr(v[i], x);
x = mr(x, getv(b[i], len - i));
if (len - i == (int)b[i].size() - 1 && i != n - 1)
del(i);
}
f.push_back(x);
if (bg == n - 1 && len - (n - 1) >= (int)b[n - 1].size() - 1)
return f;
}
return f;
}
vector<info> merge(vector<info> a, vector<info> b, bool op) {
if (a.size() > b.size())
swap(a, b);
if (op) {
for (int i = 0; i != (int)b.size(); i++)
b[i] = mr(b[i], getv(a, i));
return b;
}
for (int i = 0; i != (int)a.size(); i++)
a[i] = mr(a[i], getv(b, i));
return a;
}
void dfs1(int i) {
for (int j : to[i])
dfs1(j);
sort(to[i].begin(), to[i].end(), [&](int i, int j){ return h[i] < h[j]; });
if (!to[i].empty()) {
son[i] = to[i].back();
h[i] = h[son[i]] + 1;
to[i].pop_back();
}
for (int j : to[i])
lh[i] = max(lh[i], h[j]);
}
void dfs2(int i) {
for (int x = i; x; x = son[x])
for (int y : to[x]) {
dfs2(y);
vector<info> a = dp[y];
for (info &x : a)
x = mr(e[y], x);
a.insert(a.begin(), emptyinfo);
lson[x] = merge(lson[x], a, 1);
}
vector<int> a;
vector<vector<info>> b;
for (int x = i; x; x = son[x])
a.push_back(x), b.push_back(lson[x]), top[x] = i;
dp[i] = merge(a, b);
}
void dfs3(int i, vector<info> v);
vector<info> tmp[N];
struct sgt{
vector<int> rh, a;
vector<vector<info>> vl, vr;
void init(int n) { vl.resize(n * 4), vr.resize(n * 4), rh.resize(n * 4), a.resize(n); }
void build(int i, int l, int r, vector<info> fl, vector<info> fr) {
static vector<info> val[N];
if (l == r) {
int x = a[l], len = lh[x] + 3;
if (len < (int)fl.size())
fl.resize(len);
if (len < (int)fr.size())
fr.resize(len);
vector<info> f = merge(fl, fr, 1);
vl[i] = f;
if ((int)f.size() < len) {
if (f.empty())
f.resize(len, emptyinfo);
else
f.resize(len, f.back());
}
for (int j = 0; j != (int)to[x].size(); j++) {
int y = to[x][j];
tmp[y] = dp[y];
for (info &z : tmp[y])
z = mr(e[y], z);
tmp[y].insert(tmp[y].begin(), emptyinfo);
val[j + 1] = merge(val[j], tmp[y], 1);
}
for (int j = to[x].size() - 1; j != -1; j--) {
int y = to[x][j];
vector<info> g = merge(f, val[j], 1), tp = tmp[y];
for (info &z : g)
z = mr(e[y], z);
g.insert(g.begin(), emptyinfo);
tmp[y] = g, f = merge(f, tp, 0);
}
return;
}
int mid = (l + r) >> 1;
for (int j = mid + 1; j <= r; j++)
rh[i] = max(rh[i], j + lh[a[j]]);
rh[i] += 1;
int ll = rh[i] - l + mid - l;
int rl = rh[i] - mid;
for (int j = l; j <= mid; j++)
rl = max(rl, lh[a[j]] - (mid - j) + 1);
int tp = ll - (mid - l);
vector<info> ttl = fl, ttr = fr;
if (tp < (int)fl.size())
fl.resize(tp + 1);
tp = max(0, rl - (r - mid) + 1);
if (tp < (int)fr.size())
fr.resize(tp + 1);
vector<int> al(mid - l + 1), ar(r - mid);
vector<vector<info>> bl(mid - l + 1), br(r - mid);
for (int j = mid; j >= l; j--) {
al[mid - j] = a[j];
bl[mid - j] = lson[a[j]];
}
bl[mid - l] = merge(bl[mid - l], fl, 1);
for (int j = mid + 1; j <= r; j++) {
ar[j - mid - 1] = a[j];
br[j - mid - 1] = lson[a[j]];
}
br[r - mid - 1] = merge(br[r - mid - 1], fr, 1);
vl[i] = merge(al, bl, ll), vr[i] = merge(ar, br, rl);
for (info &x : vl[i])
x = mr(e[a[mid + 1]], x);
vl[i].insert(vl[i].begin(), emptyinfo);
vector<info> tr = vr[i];
for (info &x : tr)
x = mr(e[a[mid + 1]], x);
tr.insert(tr.begin(), emptyinfo);
build(i << 1, l, mid, ttl, tr);
build(i << 1 | 1, mid + 1, r, vl[i], ttr);
}
pair<info, bool> query(int i, int l, int r, int x, int d) {
if (l == r) {
if (d <= lh[a[x]] + 1)
return {mr(getv(vl[i], d), getv(lson[a[x]], d)), 1};
return {emptyinfo, 0};
}
int mid = (l + r) >> 1;
if (x > mid)
return query(i << 1 | 1, mid + 1, r, x, d);
pair<info, bool> tp = query(i << 1, l, mid, x, d);
if (tp.second)
return tp;
if (x + d > rh[i])
return {emptyinfo, 0};
return {mr(getv(vl[i], d + mid - x + 1), getv(vr[i], x + d - mid - 1)), 1};
}
info query(int x, int d) { return query(1, 0, a.size() - 1, x, d).first; }
}seg[N];
void dfs3(int i, vector<info> v) {
int l = h[i] + 1;
mp[i] = ++tot;
seg[tot].init(l);
for (int x = i; x; x = son[x])
seg[tot].a[h[i] - h[x]] = x;
seg[tot].build(1, 0, h[i], v, vector<info>{emptyinfo});
for (int x = i; x; x = son[x])
for (int y : to[x])
dfs3(y, tmp[y]);
}
void init(int t, int n, int q, vector<int> ff, vector<info> ee, int limit) {
for (int i = 2; i <= n; i++) {
fa[i][0] = ff[i - 2], e[i] = ee[i - 2];
to[ff[i - 2]].push_back(i);
for (int j = 1; fa[fa[i][j - 1]][j - 1]; j++)
fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
dfs1(1), dfs2(1);
dfs3(1, vector<info>{emptyinfo});
}
info ask(int u, int d) {
for (int i = 17; i != -1; i--) {
int j = fa[u][i];
if (!j)
continue;
if (h[j] + 1 < d - (1 << i)) {
u = fa[u][i];
d -= 1 << i;
}
}
if (fa[u][0] && h[u] + 1 < d)
d--, u = fa[u][0];
if (u == 1)
return getv(dp[1], d);
return seg[mp[top[u]]].query(h[top[u]] - h[u], d);
}