[解题报告] 「BJOI2017」树的难题(点分治 + 线段树/单调队列)
题意
一个点数为 \(n\) 的树,每个点有颜色 \(c_i\),每个颜色有权值 \(v_i\)。
一条路径的权值定义为该条路径上各同色连通块的权值之和。
例如,一条颜色序列为 1 2 2 1
的路径,其权值为 \(v_1 + v_2 + v_1\)。
求长度在 \([L,R]\) 之前的路径的权值最大值。
\(n \le 2 \times 10^5, c \le n\)
解法
解法一 线段树
树上和距离有关的问题首先考虑点分治。对于每个节点 u,
把它的儿子按照颜色排序,让颜色相同的儿子放在一起。
点分治时开两棵线段树,分别记录与当前颜色相同 / 不同的最大值,遍历完该种颜色后把贡献加到第二棵线段树上。
时间复杂度为 \(O(n\log^2 n)\)。
解法二 按秩合并单调队列
考虑直接把路径扣出来,然后用单调队列处理(不同颜色分开处理,相同颜色的就按照子树分开处理)。
但是由于这里单调队列会有个初始化的复杂度,就是把取值区间的左端点从 \(maxdis\) 移动到 \(L\)(其中 \(maxdis\) 表示之前遍历过的子树的最大深度),这样复杂度可以被卡到 \(O(n^2)\)。
有一个叫 “按秩合并单调队列” 的做法。
就是把扣出来的路径排序,不同颜色之间按照该颜色中路径长度最大值从小到大排序,相同颜色的路径也按照长度从小到大排序,这样单调队列的复杂度就是对的(可以看做把之前的路径长度带来的复杂度算到自己身上,然后初始化的总复杂度就是 \(O(sz)\) 的),复杂度是 \(O(n \log n)\) 的。
解法三 带回溯单调队列
当然如果像我一样蠢,可以写一个带回溯的单调队列,然后调两个晚上。
其实是因为我第一次写线段树做法的时候脑子抽了,点分治的时候把每个点的儿子 \(reverse\) 之后再做了一遍……然后它就 T 了……然后我看一眼讨论区,发现 “单调队列” 四个字,然后就想都没想就直接 \(Dfs\) 的时候用单调队列维护……然后它 WA 了……然后我发现好像要回溯……然后我写了个回溯……然后它 T 了(原因和解法二中的差不多)……然后我就调了两个晚上……
代码
线段树
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
#define pb push_back
#define mkp make_pair
#define fi first
#define se second
using namespace std;
const int _ = 2e5 + 7;
const int __ = 2e7 + 7;
const int inf = 2e9 + 7;
int n;
struct SGT {
#define mid ((l + r) >> 1)
int rt, tot, maxn[__], ls[__], rs[__];
void clear() { rt = tot = 0; }
void Modify(int &k, int l, int r, int x, int w) {
if (!k) k = ++tot, maxn[k] = -inf, ls[k] = rs[k] = 0;
if (l == r) return (void)(maxn[k] = max(maxn[k], w));
if (x <= mid) Modify(ls[k], l, mid, x, w);
else Modify(rs[k], mid + 1, r, x, w);
maxn[k] = max(ls[k] ? maxn[ls[k]] : -inf, rs[k] ? maxn[rs[k]] : -inf);
}
void Modify(int x, int w) { Modify(rt, 1, n, x, w); }
int Query(int &k, int l, int r, int x, int y) {
if (!k) k = ++tot, maxn[k] = -inf, ls[k] = rs[k] = 0;
if (l >= x and r <= y) return maxn[k];
int t1 = -inf, t2 = -inf;
if (x <= mid) t1 = Query(ls[k], l, mid, x, y);
if (y > mid) t2 = Query(rs[k], mid + 1, r, x, y);
return max(t1, t2);
}
int Query(int l, int r) {
if (l > r) return -inf;
maxn[0] = inf;
return Query(rt, 1, n, l, r);
}
#undef mid
} S, T;
int m, L, R, val[_], ans = -inf;
vector<pair<int, int>> to[_];
int gi() {
int x = 0; bool f = 0; char c = getchar();
while (!isdigit(c) and c != '-') c = getchar();
if (c == '-') f = 1, c = getchar();
while (isdigit(c)) x = (x << 3) + (x << 1) + c - '0', c = getchar();
return f ? -x : x;
}
void Init() {
n = gi(), m = gi(), L = gi(), R = gi();
for (int i = 1; i <= m; ++i) val[i] = gi();
for (int i = 1, x, y, c; i < n; ++i) {
x = gi(), y = gi(), c = gi();
to[x].pb(mkp(c, y));
to[y].pb(mkp(c, x));
}
for (int i = 1; i <= n; ++i) sort(to[i].begin(), to[i].end());
}
int rt, minx, numV, sz[_], top;
pair<int, int> box[_];
bool vis[_];
void GetSz(int u, int fa) {
++numV, sz[u] = 1;
for (auto x: to[u])
if (!vis[x.se] and x.se != fa) GetSz(x.se, u), sz[u] += sz[x.se];
}
void FindRt(int u, int fa) {
int maxsz = 0;
for (auto x: to[u])
if (!vis[x.se] and x.se != fa) FindRt(x.se, u), maxsz = max(maxsz, sz[x.se]);
maxsz = max(maxsz, numV - sz[u]);
if (maxsz < minx) rt = u, minx = maxsz;
}
void Stat(int u, int fa, int w, int lst, int t, int dis) {
if (dis > R) return;
ans = max(ans, w + S.Query(max(1, L - dis), min(n, R - dis)));
ans = max(ans, w - t + T.Query(max(1, L - dis), min(n, R - dis)));
if (dis >= L) ans = max(ans, w);
for (auto x: to[u])
if (!vis[x.se] and x.se != fa) Stat(x.se, u, w + (x.fi != lst) * val[x.fi], x.fi, t, dis + 1);
}
void Cont(int u, int fa, int w, int lst, int dis) {
if (dis > R) return;
T.Modify(dis, w), box[++top] = mkp(dis, w);
for (auto x: to[u])
if (!vis[x.se] and x.se != fa) Cont(x.se, u, w + (x.fi != lst) * val[x.fi], x.fi, dis + 1);
}
void Calc() {
int lst = 0; top = 0;
for (auto x: to[rt]) {
int v = x.se;
if (vis[v]) continue;
if (x.fi != lst) {
while (top) S.Modify(box[top].fi, box[top].se), --top;
T.clear(), lst = x.fi;
}
Stat(v, rt, val[x.fi], x.fi, val[x.fi], 1);
Cont(v, rt, val[x.fi], x.fi, 1);
}
S.clear(), T.clear();
}
void Work(int u) {
rt = 0, minx = inf, numV = 0;
GetSz(u, 0);
FindRt(u, 0);
int tmp = rt; vis[rt] = 1;
for (auto x: to[tmp])
if (!vis[x.se]) Work(x.se);
rt = tmp, vis[rt] = 0;
Calc();
}
int main() {
Init();
Work(1);
cout << ans << endl;
return 0;
}
带回溯单调队列
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
#define pb push_back
#define mkp make_pair
#define fi first
#define se second
using namespace std;
const int _ = 2e5 + 7;
const int __ = 2e6 + 7;
const int inf = 2e9 + 7;
int n;
struct SGT {
#define mid ((l + r) >> 1)
int lc, rc, maxn[_], box[_], top, val, q[_], t1, t2, cnt, tot;
pair<int, int> a[__];
struct NODE { int lc, rc, t1, t2, st; } rev[_];
void Init() { memset(maxn, -0x3f, sizeof maxn); lc = rc = n + 1, t1 = 1, t2 = 0; }
void Recall() {
t1 = rev[cnt].t1, t2 = rev[cnt].t2, lc = rev[cnt].lc, rc = rev[cnt].rc;
while (tot > rev[cnt].st) q[a[tot].fi] = a[tot].se, --tot;
--cnt;
}
void clear() {
int x = 0;
while (top) x = max(x, box[top]), maxn[box[top--]] = -inf;
while (cnt and lc <= x) Recall();
}
void reset(int p) { lc = rc = p, t1 = 1, t2 = 0, cnt = 0; }
void Modify(int x, int w) {
box[++top] = x;
maxn[x] = max(maxn[x], w);
while (cnt and lc <= x) Recall();
}
int Query(int l, int r) {
rev[++cnt] = { lc, rc, t1, t2, tot };
while (lc > l) {
--lc;
while (t2 >= t1 and maxn[q[t2]] <= maxn[lc]) {
a[++tot] = mkp(t2, q[t2]);
--t2;
}
q[++t2] = lc;
}
rc = min(rc, r);
while (t1 <= t2 and (q[t1] > rc or q[t1] < lc)) a[++tot] = mkp(t1, q[t1]), ++t1;
return t1 > t2 ? -inf : maxn[q[t1]];
}
#undef mid
} S, T;
int m, L, R, val[_], ans = -inf;
vector<pair<int, int>> to[_];
int gi() {
int x = 0; bool f = 0; char c = getchar();
while (!isdigit(c) and c != '-') c = getchar();
if (c == '-') f = 1, c = getchar();
while (isdigit(c)) x = (x << 3) + (x << 1) + c - '0', c = getchar();
return f ? -x : x;
}
void Init() {
n = gi(), m = gi(), L = gi(), R = gi();
for (int i = 1; i <= m; ++i) val[i] = gi();
for (int i = 1, x, y, c; i < n; ++i) {
x = gi(), y = gi(), c = gi();
to[x].pb(mkp(c, y));
to[y].pb(mkp(c, x));
}
for (int i = 1; i <= n; ++i) sort(to[i].begin(), to[i].end());
S.Init(), T.Init();
}
int rt, minx, numV, sz[_], top, dep[_];
pair<int, int> box[_];
bool vis[_];
void GetSz(int u, int fa) {
++numV, sz[u] = 1, dep[u] = 1;
for (auto x: to[u])
if (!vis[x.se] and x.se != fa) GetSz(x.se, u), sz[u] += sz[x.se], dep[u] = max(dep[u], dep[x.se] + 1);
}
void FindRt(int u, int fa) {
int maxsz = 0;
for (auto x: to[u])
if (!vis[x.se] and x.se != fa) FindRt(x.se, u), maxsz = max(maxsz, sz[x.se]);
maxsz = max(maxsz, numV - sz[u]);
if (maxsz < minx) rt = u, minx = maxsz;
}
void Stat(int u, int fa, int w, int lst, int t, int dis) {
if (dis > R) return;
ans = max(ans, w + S.Query(max(1, L - dis), min(n, R - dis)));
ans = max(ans, w - t + T.Query(max(1, L - dis), min(n, R - dis)));
if (dis >= L) ans = max(ans, w);
for (auto x: to[u])
if (!vis[x.se] and x.se != fa) Stat(x.se, u, w + (x.fi != lst) * val[x.fi], x.fi, t, dis + 1);
S.Recall();
T.Recall();
}
void Cont(int u, int fa, int w, int lst, int dis) {
if (dis > R) return;
T.Modify(dis, w);
box[++top] = mkp(dis, w);
for (auto x: to[u])
if (!vis[x.se] and x.se != fa) Cont(x.se, u, w + (x.fi != lst) * val[x.fi], x.fi, dis + 1);
}
void Calc() {
int lst = 0; top = 0;
S.reset(dep[rt] + 1), T.reset(dep[rt] + 1);
while (S.lc > L) S.Query(S.lc - 1, R);
while (T.lc > L) T.Query(T.lc - 1, R);
for (auto x: to[rt]) {
int v = x.se;
if (vis[v]) continue;
if (x.fi != lst) {
while (top) S.Modify(box[top].fi, box[top].se), --top;
T.clear();
lst = x.fi;
}
while (S.lc > L) S.Query(S.lc - 1, R);
while (T.lc > L) T.Query(T.lc - 1, R);
Stat(v, rt, val[x.fi], x.fi, val[x.fi], 1);
Cont(v, rt, val[x.fi], x.fi, 1);
}
S.clear();
T.clear();
}
void Work(int u) {
rt = 0, minx = inf, numV = 0;
GetSz(u, 0);
FindRt(u, 0);
int tmp = rt; vis[rt] = 1;
for (auto x: to[tmp])
if (!vis[x.se]) Work(x.se);
rt = tmp, vis[rt] = 0;
Calc();
}
int main() {
Init();
Work(1);
cout << ans << endl;
return 0;
}