CF490F Treeland Tour (线段树合并)
线段树合并
考虑维护 \(lis_{u,i}/lds_{u,i}\) 当前子树 \(u\) 中以 \(i\) 结尾的上升子序列/下降子序列。考虑转移,实质上就是合并每个儿子的信息,用线段树合并即可。
考虑如何统计答案,当枚举到儿子 \(v\) 时,维护答案分两种情况:
选 \(u\) 点,那么就是前面的 \(ans=\max(ans, maxlis+vlds+1)\) 和 \(ans=\max(ans, maxlis+vlds+1)\)
不选 \(u\) 点,考虑在合并操作中统计答案。在合并时会枚举断点 \(mid\),此时 \(ans=\max(ans, lis_{p1_{ls}}+lds_{p2_{rs}})\) 或 \(ans=\max(ans, lis_{p2_{ls}}+lds_{p1_{rs}})\)。
这部分可能就我看得懂
复杂度 \(O(n\log n)\)。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define fi first
#define se second
#define pb push_back
typedef long long i64;
const int N = 6010;
int n, cnt, num, ans;
int a[N];
int h[N];
struct node {
int to, nxt;
} e[N << 1];
void add(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = h[u];
h[u] = cnt;
}
int ls[4000010], rs[1000010];
struct LS {
int tot;
int t[4000010];
void pushup(int u, int ls, int rs) {t[u] = std::max(t[ls], t[rs]);}
void ins(int &u, int l, int r, int x, int y) {
if(!u) u = ++tot;
if(l == r) {
t[u] = std::max(t[u], y);
return;
}
int mid = (l + r) >> 1;
if(x <= mid) ins(ls[u], l, mid, x, y);
else ins(rs[u], mid + 1, r, x, y);
pushup(u, ls[u], rs[u]);
}
int query(int u, int l, int r, int L, int R) {
if(R < L) return 0;
if(L <= l && r <= R) {
return t[u];
}
int mid = (l + r) >> 1, ret = 0;
if(L <= mid) ret = std::max(ret, query(ls[u], l, mid, L, R));
if(R > mid) ret = std::max(ret, query(rs[u], mid + 1, r, L, R));
return ret;
}
} lis, lds;
void mg(int p1, int p2, int l, int r) {
if(l == r) {
lis.t[p1] = std::max(lis.t[p1], lis.t[p2]);
lds.t[p1] = std::max(lds.t[p1], lds.t[p2]);
return;
}
int mid = (l + r) >> 1;
int lv = ls[p1], rv = rs[p2];
ans = std::max(ans, lis.t[lv] + lds.t[rv]);
lv = ls[p2], rv = rs[p1];
ans = std::max(ans, lis.t[lv] + lds.t[rv]);
if(ls[p1] && ls[p2]) mg(ls[p1], ls[p2], l, mid);
else if(ls[p2]) ls[p1] = ls[p2];
if(rs[p1] && rs[p2]) mg(rs[p1], rs[p2], mid + 1, r);
else if(rs[p2]) rs[p1] = rs[p2];
lis.pushup(p1, ls[p1], rs[p1]);
lds.pushup(p1, ls[p1], rs[p1]);
}
void dfs(int u, int fa) {
int mxlis = 0, mxlds = 0;
for(int i = h[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
dfs(v, u);
int vlis = lis.query(v, 1, num, 1, a[u] - 1), vlds = lds.query(v, 1, num, a[u] + 1, num);
ans = std::max(ans, std::max(mxlis + vlds + 1, vlis + mxlds + 1));
mxlis = std::max(mxlis, vlis);
mxlds = std::max(mxlds, vlds);
mg(u, v, 1, n);
}
lis.ins(u, 1, num, a[u], mxlis + 1);
lds.ins(u, 1, num, a[u], mxlds + 1);
}
int b[N];
void Solve() {
std::cin >> n;
lis.tot = lds.tot = n;
for(int i = 1; i <= n; i++) {
std::cin >> a[i];
b[++num] = a[i];
}
std::sort(b + 1, b + n + 1);
num = std::unique(b + 1, b + num + 1) - b - 1;
for(int i = 1; i <= n; i++) {
a[i] = std::lower_bound(b + 1, b + num + 1, a[i]) - b;
}
for(int i = 1; i < n; i++) {
int u, v;
std::cin >> u >> v;
add(u, v), add(v, u);
}
dfs(1, 0);
std::cout << ans << "\n";
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
Solve();
return 0;
}