Bzoj 2286 & Luogu P2495 消耗战(LCA+虚树+欧拉序)
题面
题解
很容易想到$O(nk)$的树形$dp$吧,设$f[i]$表示处理完这$i$颗子树的最小花费,同时再设一个$mi[i]$表示$i$到根节点$1$路径上的距离最小值。于是有:
$ f[i]=\sum min(f[son[i]], mi[son[i]]) $
这样就有$40$分了。
考虑优化:这里可以用虚树来优化,先把所有点按照$DFS$序进行排序,然后将相邻两个点的$LCA$以及$1$号点加入进$LCA$,然后虚树就构好了,考虑欧拉序的特殊性质,所以再还原出欧拉序,在上面做$dp$就好了。(xgzc告诉我可以再$dfs$一遍,但我不想写了,欧拉序多好啊)
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using std::min; using std::max;
using std::swap; using std::sort;
typedef long long ll;
template<typename T>
void read(T &x) {
int flag = 1; x = 0; char ch = getchar();
while(ch < '0' || ch > '9') { if(ch == '-') flag = -flag; ch = getchar(); }
while(ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
x *= flag;
}
const int N = 2.5e5 + 10, Inf = 1e9 + 7;
int n, m, dfin[N], dfout[N], tim;
int cnt, from[N], to[N << 1], nxt[N << 1];
int siz[N], son[N], dep[N], top[N], fa[N];
int nt[N << 1], vis[N], s[N << 1], tt;
ll mi[N], f[N], dis[N << 1];
inline void addEdge(int u, int v, ll w) {
to[++cnt] = v, dis[cnt] = w, nxt[cnt] = from[u], from[u] = cnt;
}
inline bool cmp(const int &x, const int &y) {
int k1 = x > 0 ? dfin[x] : dfout[-x], k2 = y > 0 ? dfin[y] : dfout[-y];
return k1 < k2;
}
void dfs(int u) {
siz[u] = 1, dfin[u] = ++tim, dep[u] = dep[fa[u]] + 1;
for(int i = from[u]; i; i = nxt[i]) {
int v = to[i]; if(v == fa[u]) continue;
mi[v] = min(mi[u], dis[i]);
fa[v] = u, dfs(v), siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
} dfout[u] = ++tim;
}
void dfs(int u, int t) {
top[u] = t; if(!son[u]) return ;
dfs(son[u], t);
for(int i = from[u]; i; i = nxt[i])
if(to[i] != son[u] && to[i] != fa[u]) dfs(to[i], to[i]);
}
int lca(int x, int y) {
int fx = top[x], fy = top[y];
while(fx != fy)
if(dep[fx] > dep[fy]) x = fa[fx], fx = top[x];
else y = fa[fy], fy = top[y];
return dep[x] < dep[y] ? x : y;
}
int main () {
read(n);
for(int i = 1; i < n; ++i) {
int u, v; ll w; read(u), read(v), read(w);
addEdge(u, v, w), addEdge(v, u, w);
}
mi[1] = Inf, dfs(1), dfs(1, 1), read(m);
for(int i = 1; i <= m; ++i) {
int tot; read(tot);
for(int j = 1; j <= tot; ++j)
read(nt[j]), vis[nt[j]] = true, f[nt[j]] = mi[nt[j]];
sort(&nt[1], &nt[tot + 1], cmp);
for(int j = 1; j < tot; ++j) {
int cf = lca(nt[j], nt[j + 1]);
if(!vis[cf]) nt[++tot] = cf, vis[cf] = true;
}
int tmp = tot;
for(int j = 1; j <= tmp; ++j)
nt[++tot] = -nt[j];
if(!vis[1]) nt[++tot] = 1, nt[++tot] = -1;
sort(&nt[1], &nt[tot + 1], cmp);
for(int j = 1; j <= tot; ++j)
if(nt[j] > 0) s[++tt] = nt[j];
else {
int now = s[tt--];
if(now != 1) { int fat = s[tt]; f[fat] += min(f[now], mi[now]); }
else printf("%lld\n", f[1]);
f[now] = vis[now] = 0;
}
}
return 0;
}