BZOJ 2286 消耗战 - 虚树 + 树型dp
题目大意:
每次给出k个特殊点,回答将这些特殊点与根节点断开至少需要多少代价。
题目分析:
虚树
入门 + 树型dp:
刚刚学习完虚树(好文),就来这道入门题签个到。
虚树就是将树中的一些关键点提取出来,在不改变父子关系的条件下用$O(mlog n) \(组成一颗新树(m特殊点数,n总数),大小为\)O(m)$,以便降低后续dp等的复杂度。
建好虚树过后就可以进行普通的dp了(mn[u]表示原图中u到根节点的最短边长):
\[dp[u] = mn[u] (u是特殊点)
\]
\[dp[u] = min(mn[u], \sum{dp[son[u]]}) (u不是特殊点)
\]
此题就当做是虚树模板了。
注意每一次重建虚树时不用将1~n的边信息全部清空,不然你会见识到clear的惊人速度(T飞)。
code
#include<bits/stdc++.h>
using namespace std;
#define maxn 250050
#define oo 0x3f3f3f3f
typedef long long ll;
typedef pair<int, ll> pil;
namespace IO{
inline int read(){
int i = 0, f = 1; char ch = getchar();
for(; (ch < '0' || ch > '9') && ch != '-'; ch = getchar());
if(ch == '-') f = -1, ch = getchar();
for(; ch >= '0' && ch <= '9'; ch = getchar()) i = (i << 3) + (i << 1) + (ch - '0');
return i * f;
}
inline void wr(ll x){
if(x < 0) x = -x, putchar('-');
if(x > 9) wr(x / 10);
putchar(x % 10 + '0');
}
}using namespace IO;
int n, m;
vector<pil> g[maxn];
vector<int> vg[maxn];
ll dp[maxn], mn[maxn];
int dfn[maxn], clk, dep[maxn], vir[maxn], virCnt, par[maxn], rt;
bool key[maxn];
namespace LCA{
int pos[maxn], top[maxn], son[maxn], sze[maxn], tot, fa[maxn];
inline void dfs1(int u, int f){
dfn[u] = ++clk;
dep[u] = dep[f] + 1;
fa[u] = f;
sze[u] = 1;
for(int i = g[u].size() - 1; i >= 0; i--){
int v = g[u][i].first;
if(v == f) continue;
mn[v] = min(mn[u], g[u][i].second);
dfs1(v, u);
sze[u] += sze[v];
if(sze[v] > sze[son[u]] || !son[u]) son[u] = v;
}
}
inline void dfs2(int u, int f){
if(son[u]){
pos[son[u]] = ++tot;
top[son[u]] = top[u];
dfs2(son[u], u);
}
for(int i = g[u].size() - 1; i >= 0; i--){
int v = g[u][i].first;
if(v == f || v == son[u]) continue;
pos[v] = ++tot;
top[v] = v;
dfs2(v, u);
}
}
inline void splitTree(){
dfs1(1, 0);
pos[tot = 1] = top[1] = 1;
dfs2(1, 0);
}
inline int getLca(int u, int v){
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]) swap(u, v);
u = fa[top[u]];
}
return dep[u] < dep[v] ? u : v;
}
}
inline bool cmp(int u, int v){
return dfn[u] < dfn[v];
}
inline void buildVir(){
static int stk[maxn], top;
top = 0;
sort(vir + 1, vir + virCnt + 1, cmp);
int oriSze = virCnt;
for(int i = 1; i <= oriSze; i++){
int u = vir[i];
if(!top){
stk[++top] = u;
par[u] = 0;
continue;
}
int lca = LCA::getLca(stk[top], u);
while(dep[lca] < dep[stk[top]]){
if(dep[stk[top - 1]] < dep[lca]) par[stk[top]] = lca;
--top;
}
if(lca != stk[top]){
vir[++virCnt] = lca;
par[lca] = stk[top];
stk[++top] = lca;
}
par[u] = lca;
stk[++top] = u;
}
for(int i = 1; i <= virCnt; i++) vg[vir[i]].clear();
for(int i = 1; i <= virCnt; i++){
int u = vir[i];
key[u] = ((i <= oriSze) ? 1 : 0);
if(par[u]) vg[par[u]].push_back(u);
}
sort(vir + 1, vir + virCnt + 1, cmp);
}
inline void DP(int u){
// cout<<u<<"!";
ll ret = 0;
for(int i = vg[u].size() - 1; i >= 0; i--){
int v = vg[u][i];
DP(v);
ret += dp[v];
}
if(key[u]) dp[u] = mn[u];
else dp[u] = min(mn[u], ret);
}
inline void solve(){
buildVir();
DP(vir[1]);
wr(dp[vir[1]]);
putchar('\n');
}
int main(){
freopen("h.in", "r", stdin);
n = read();
for(int i = 1; i < n; i++){
int x = read(), y = read();
ll c = 1ll * read();
g[x].push_back(pil(y, c));
g[y].push_back(pil(x, c));
}
memset(mn, oo, sizeof mn);
LCA::splitTree();
m = read();
for(int i = 1; i <= m; i++){
int k = read();
virCnt = 0;
for(int j = 1; j <= k; j++)
vir[++virCnt] = read();
solve();
}
return 0;
}