[虚树][树形dp] Bzoj P2286 消耗战
Description
在一场战争中,战场由n个岛屿和n-1个桥梁组成,保证每两个岛屿间有且仅有一条路径可达。现在,我军已经侦查到敌军的总部在编号为1的岛屿,而且他们已经没有足够多的能源维系战斗,我军胜利在望。已知在其他k个岛屿上有丰富能源,为了防止敌军获取能源,我军的任务是炸毁一些桥梁,使得敌军不能到达任何能源丰富的岛屿。由于不同桥梁的材质和结构不同,所以炸毁不同的桥梁有不同的代价,我军希望在满足目标的同时使得总代价最小。
侦查部门还发现,敌军有一台神秘机器。即使我军切断所有能源之后,他们也可以用那台机器。机器产生的效果不仅仅会修复所有我军炸毁的桥梁,而且会重新随机资源分布(但可以保证的是,资源不会分布到1号岛屿上)。不过侦查部门还发现了这台机器只能够使用m次,所以我们只需要把每次任务完成即可。
Input
第一行一个整数n,代表岛屿数量。
接下来n-1行,每行三个整数u,v,w,代表u号岛屿和v号岛屿由一条代价为c的桥梁直接相连,保证1<=u,v<=n且1<=c<=100000。
第n+1行,一个整数m,代表敌方机器能使用的次数。
接下来m行,每行一个整数ki,代表第i次后,有ki个岛屿资源丰富,接下来k个整数h1,h2,…hk,表示资源丰富岛屿的编号。
Output
输出有m行,分别代表每次任务的最小代价。
Sample Input
10
1 5 13
1 9 6
2 1 19
2 4 8
2 3 91
5 6 8
7 5 4
7 8 31
10 7 9
3
2 10 6
4 5 7 8 3
3 9 4 6
1 5 13
1 9 6
2 1 19
2 4 8
2 3 91
5 6 8
7 5 4
7 8 31
10 7 9
3
2 10 6
4 5 7 8 3
3 9 4 6
Sample Output
12
32
22
32
22
HINT
对于100%的数据,2<=n<=250000,m>=1,sigma(ki)<=500000,1<=ki<=n-1
题目大意
- 给定一棵有n个节点的树,有q次询问,每次给定k个节点,问最少删掉多少条边可以使这1号节点与k个节点不连通
题解
- 首先,我们考虑一下只有一次询问的时候怎么做
- 显然,一个树形dp就可以做到O(N)
- 那么对于多次询问,我们可以把单独把这k个关键点提出来构成一棵虚树,然后将路径压缩
- 然后像上面,在虚树上做树形dp
代码
1 #include <cstdio> 2 #include <iostream> 3 #include <cstring> 4 #include <algorithm> 5 #define inf 1e60 6 #define N 250010 7 #define ll long long 8 using namespace std; 9 int mi[20]; 10 int n,m,cnt,tot,top,K,last[N],head[N],fa[N][20],h[N],id[N],deep[N],st[N]; 11 ll mn[N],f[N]; 12 struct edge{ int to,from,v; }e[N*2],E[N*2]; 13 bool cmp(int x,int y) { return id[x]<id[y]; } 14 void insert(int x,int y,int z) 15 { 16 e[++cnt].to=y,e[cnt].from=head[x],e[cnt].v=z,head[x]=cnt; 17 e[++cnt].to=x,e[cnt].from=head[y],e[cnt].v=z,head[y]=cnt; 18 } 19 void add(int x,int y) 20 { 21 if (x==y) return; 22 E[++cnt].to=y,E[cnt].from=last[x],last[x]=cnt; 23 } 24 void pre(int x) 25 { 26 id[x]=++tot; 27 for (int i=1;mi[i]<=deep[x];i++) fa[x][i]=fa[fa[x][i-1]][i-1]; 28 for (int i=head[x];i;i=e[i].from) if (e[i].to!=fa[x][0]) mn[e[i].to]=min(mn[x],(ll)e[i].v),deep[e[i].to]=deep[x]+1,fa[e[i].to][0]=x,pre(e[i].to); 29 } 30 int lca(int x,int y) 31 { 32 if (deep[x]<deep[y]) swap(x,y); 33 int t=deep[x]-deep[y]; 34 for (int i=0;mi[i]<=t;i++) if (t&mi[i]) x=fa[x][i]; 35 for (int i=19;i>=0;i--) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; 36 if (x==y) return x; 37 return fa[x][0]; 38 } 39 void dp(int x) 40 { 41 f[x]=mn[x]; ll r=0; 42 for (int i=last[x];i;i=E[i].from) dp(E[i].to),r+=f[E[i].to]; 43 last[x]=0; 44 if (r==0) f[x]=mn[x]; else f[x]=min(f[x],r); 45 } 46 void solve() 47 { 48 scanf("%d",&K),tot=1,cnt=0; 49 for (int i=1;i<=K;i++) scanf("%d",&h[i]); 50 sort(h+1,h+K+1,cmp); 51 for (int i=2;i<=K;i++) if (lca(h[tot],h[i])!=h[tot]) h[++tot]=h[i]; 52 st[top=1]=1; 53 for (int i=1;i<=tot;i++) 54 { 55 int now=h[i],LCA=lca(now,st[top]); 56 while (1) 57 { 58 if (deep[LCA]>=deep[st[top-1]]) 59 { 60 add(LCA,st[top--]); 61 if (st[top]!=LCA) st[++top]=LCA; 62 break; 63 } 64 add(st[top-1],st[top]),top--; 65 } 66 if (st[top]!=now) st[++top]=now; 67 } 68 while (--top) add(st[top],st[top+1]); 69 dp(1),printf("%lld\n",f[1]); 70 } 71 int main() 72 { 73 scanf("%d",&n); 74 for (int i=1,x,y,z;i<n;i++) scanf("%d%d%d",&x,&y,&z),insert(x,y,z); 75 mi[0]=1; for (int i=1;i<20;i++) mi[i]=mi[i-1]*2; 76 mn[1]=inf,pre(1),scanf("%d",&m); 77 for (int i=1;i<=m;i++) solve(); 78 }