P2495 [SDOI2011] 消耗战
题意
给定一棵有边权的无根树。
\(q\) 次询问,每次询问 \(k\) 个点。
求断边使得根节点 \(1\) 与 \(k\) 个点不连通的最小边权。
Sol
虚树。
\(n ^ 2\) dp 是 trivial 的。
考虑优化。注意到其中很多点都是无用的。
考虑保留有效点。
不难发现,有效点集为询问点两两 \(lca\) 的集合。
对于两两 \(lca\),我们可以按照 \(dfn\) 排序,变为相邻两两的 \(lca\)。
总点数 \(cnt = n + n / 2 + n / 4 + ... = 2 * n\)。
正确性显然。
问题变为:如何构造这棵树?
维护一个单调栈,维护当前最右端的这条链。
考虑当前节点与栈顶的 \(lca\) 是否为栈顶,也就是判断当前点是否在链上。
否则考虑删除,并加入当前点与栈顶的 \(lca\)。把删掉的点全部连起来就行了。
回到这道题。
使用虚树,将总 \(dp\) 的点数控制在了 \(n\) 的级别。
根据题意,需要维护链上的最小值,不带修。这是 trivial 的,使用 ST 表轻松解决。
Code
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#include <cmath>
#include <bitset>
#include <vector>
#define int long long
#define pii pair <int, int>
using namespace std;
#ifdef ONLINE_JUDGE
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;
#endif
int read() {
int p = 0, flg = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') flg = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
p = p * 10 + c - '0';
c = getchar();
}
return p * flg;
}
void write(int x) {
if (x < 0) {
x = -x;
putchar('-');
}
if (x > 9) {
write(x / 10);
}
putchar(x % 10 + '0');
}
#define fi first
#define se second
const int N = 2.5e5 + 5, M = 1e6 + 5, inf = 1e18;
namespace G {
array <int, N> fir;
array <int, M> nex, to, len;
int cnt;
void add(int x, int y, int z) {
cnt++;
nex[cnt] = fir[x];
to[cnt] = y;
len[cnt] = z;
fir[x] = cnt;
}
}
namespace T {
array <int, N> fir;
array <int, M> nex, to;
int cnt;
void add(int x, int y) {
/* write(x), putchar(32); */
/* write(y), puts(""); */
cnt++;
nex[cnt] = fir[x];
to[cnt] = y;
fir[x] = cnt;
}
}
namespace ST {
array <array <int, 21>, N> sT;
array <int, N> s, lg;
void init(int n) {
for (int i = 1; i <= n; i++)
lg[i] = log2(i);
for (int i = 1; i <= n; i++)
sT[i].fill(inf);
for (int i = 1; i <= n; i++)
sT[i][0] = s[i];
for (int j = 1; j < 21; j++)
for (int i = 1; i + (1 << j) - 1 <= n; i++)
sT[i][j] = min(sT[i][j - 1], sT[i + (1 << (j - 1))][j - 1]);
}
int query(int x, int y) {
if (x > y) return inf;
int l = lg[y - x + 1];
return min(sT[x][l], sT[y - (1 << l) + 1][l]);
}
}
namespace Hpt {
using G::fir; using G::nex; using G::to;
array <int, N> siz, dep, fa, son;
array <int, N> cur;
void dfs1(int x) {
siz[x] = 1;
for (int i = fir[x]; i; i = nex[i]) {
if (to[i] == fa[x]) continue;
fa[to[i]] = x;
dep[to[i]] = dep[x] + 1;
cur[to[i]] = G::len[i];
dfs1(to[i]);
siz[x] += siz[to[i]];
if (siz[to[i]] > siz[son[x]]) son[x] = to[i];
}
}
array <int, N> dfn, idx, top;
int cnt;
void dfs2(int x, int Mgn) {
cnt++;
dfn[x] = cnt;
idx[cnt] = x;
top[x] = Mgn;
if (son[x]) dfs2(son[x], Mgn);
for (int i = fir[x]; i; i = nex[i]) {
if (to[i] == fa[x] || to[i] == son[x]) continue;
dfs2(to[i], to[i]);
}
}
int lca(int x, int y) {
while (top[x] != top[y]) {
if (dfn[top[x]] < dfn[top[y]]) swap(x, y);
x = fa[top[x]];
}
if (dfn[x] > dfn[y]) swap(x, y);
return x;
}
int query(int x, int y) {
int ans = inf;
while (top[x] != top[y]) {
if (dfn[top[x]] < dfn[top[y]]) swap(x, y);
ans = min(ans, ST::query(dfn[top[x]], dfn[x]));
x = fa[top[x]];
}
if (dfn[x] > dfn[y]) swap(x, y);
ans = min(ST::query(dfn[x] + 1, dfn[y]), ans);
return ans;
}
}
namespace Vit {
vector <pii> isl;
array <int, N> stk;
int tp;
void build() {
sort(isl.begin(), isl.end());
tp = 1; stk[tp] = 1; T::fir[1] = 0;
for (auto x : isl) {
if (x.se == 1) continue;
int lcA = Hpt::lca(stk[tp], x.se);
/* write(lcA), puts("#"); */
if (lcA != stk[tp]) {
while (Hpt::dfn[stk[tp - 1]] > Hpt::dfn[lcA])
T::add(stk[tp - 1], stk[tp]), tp--;
if (stk[tp - 1] != lcA)
T::fir[lcA] = 0, T::add(lcA, stk[tp]), stk[tp] = lcA;
else
T::add(stk[tp - 1], stk[tp]), tp--;
}
T::fir[x.se] = 0, tp++, stk[tp] = x.se;
}
for (int i = 1; i < tp; i++)
T::add(stk[i], stk[i + 1]);
}
bitset <N> vis;
int dfs(int x) {
int res = 0;
for (int i = T::fir[x]; i; i = T::nex[i]) {
int len = Hpt::query(x, T::to[i]);
if (vis[T::to[i]]) {
res += len;
continue;
}
res += min(dfs(T::to[i]), len);
}
return res;
}
}
signed main() {
int n = read();
for (int i = 2; i <= n; i++) {
int x = read(), y = read(), z = read();
G::add(x, y, z), G::add(y, x, z);
}
Hpt::dfs1(1), Hpt::dfs2(1, 1);
for (int i = 1; i <= n; i++)
ST::s[i] = Hpt::cur[Hpt::idx[i]];
ST::init(n);
/* for (int i = 1; i <= n; i++) */
/* write(ST::s[i]), putchar(32); */
/* puts(""); */
/* write(Hpt::query(1, 3)), puts(""); */
/* return 0; */
int q = read();
while (q--) {
int k = read();
for (int i = 1; i <= k; i++) {
int x = read(); Vit::vis[x] = 1;
Vit::isl.push_back(make_pair(Hpt::dfn[x], x));
}
Vit::build();
write(Vit::dfs(1)), puts("");
for (auto x : Vit::isl)
Vit::vis[x.se] = 0;
Vit::isl.clear();
}
return 0;
}