BZOJ 2286 消耗战 (虚树+树形DP)
给出一个n节点的无向树,每条边都有一个边权,给出m个询问,
每个询问询问ki个点,问切掉一些边后使得这些顶点无法与顶点1连接。
最少的边权和是多少。
(n<=250000,sigma(ki)<=500000)
考虑树形DP,我们令mn[i]表示i节点无法与1节点相连切除的最小权值。
显然有mn[i]=min(E(fa,i),mn[fa]).
大致就是i到1的简单路径上的最小边。
我们对于每个询问。把询问的点不妨称为关键点。
令dp[i]表示i节点不能与子树的关键点连接切掉的最小权值。
那么有,如果son[i]是关键点,则dp[i]+=E(i,son(i)).
如果son[i]不是关键点,则dp[i]+=min(dp[son(i)],E(i,son(i))).
考虑最坏每次只询问一个点,则复杂度为O(n*sigma(ki)).显然无法承受。
我们观察到sigma(ki)有限制,这启发了我们构造一颗新树,这棵树称为虚树。
我们把每个节点和每对节点的lca单独拉出来模仿原来的树的形态构造一颗虚树。
这样再在这颗新树上进行树形DP。
构造这棵树的核心思想是每次维护一条最右边的链。
首先把关键点按dfs序排序。
然后相邻的点取lca。
再单调栈维护一下最右边的链就ok啦。
# include <stdio.h> # include <string.h> # include <stdlib.h> # include <iostream> # include <vector> # include <queue> # include <stack> # include <map> # include <math.h> # include <algorithm> using namespace std; # define lowbit(x) ((x)&(-x)) # define pi acos(-1.0) # define MAXN 250005 # define eps 1e-5 # define MAXM 500001 # define MOD 1000000007 # define INF 1000000000 # define mem(a,b) memset(a,b,sizeof(a)) # define FOR(i,a,n) for(int i=a; i<=n; ++i) # define FO(i,a,n) for(int i=a; i<n; ++i) # define bug puts("H"); typedef long long LL; typedef unsigned long long ULL; int _MAX(int a, int b){return a>b?a:b;} int _MIN(int a, int b){return a>b?b:a;} int Scan() { int res=0, flag=0; char ch; if((ch=getchar())=='-') flag=1; else if(ch>='0'&&ch<='9') res=ch-'0'; while((ch=getchar())>='0'&&ch<='9') res=res*10+(ch-'0'); return flag?-res:res; } void Out(int a) { if(a<0) {putchar('-'); a=-a;} if(a>=10) Out(a/10); putchar(a%10+'0'); } struct Edge{int p, next, w;}edge[MAXN<<1]; int head[MAXN], cnt=1, bin[20], ind; int id[MAXN], dep[MAXN], fa[MAXN][20], h[MAXN], st[MAXN], top; LL ans[MAXN], dp[MAXN]; void add_edge(int u, int v, int w) { if (u==v) return ; edge[cnt].p=v; edge[cnt].next=head[u]; edge[cnt].w=w; head[u]=cnt++; } void bin_init(){bin[0]=1; FO(i,1,20) bin[i]=bin[i-1]<<1;} bool comp(int a, int b){return id[a]<id[b];} void dfs(int x, int fat) { id[x]=++ind; fa[x][0]=fat; for (int i=1; bin[i]<=dep[x]; ++i) fa[x][i]=fa[fa[x][i-1]][i-1]; for (int i=head[x]; i; i=edge[i].next) { int v=edge[i].p; if (v==fat) continue; dep[v]=dep[x]+1; ans[v]=min(ans[x],(LL)edge[i].w); dfs(v,x); } } int lca(int x, int y) { if (dep[x]<dep[y]) swap(x,y); int t=dep[x]-dep[y]; for (int i=0; bin[i]<=t; ++i) if (bin[i]&t) x=fa[x][i]; for (int i=19; i>=0; --i) if (fa[x][i]!=fa[y][i]) x=fa[x][i], y=fa[y][i]; if (x==y) return x; else return fa[x][0]; } void dp_dfs(int x) { dp[x]=ans[x]; LL temp=0; for (int i=head[x]; i; i=edge[i].next) { int v=edge[i].p; dp_dfs(v); temp+=dp[v]; } head[x]=0; if (temp) dp[x]=min(dp[x],temp); } void sol() { int k, tot=0; cnt=1; scanf("%d",&k); FOR(i,1,k) h[i]=Scan(); sort(h+1,h+k+1,comp); h[++tot]=h[1]; FOR(i,2,k) if (lca(h[tot],h[i])!=h[tot]) h[++tot]=h[i]; st[++top]=1; FOR(i,1,tot) { int f=lca(h[i],st[top]); while (dep[f]<dep[st[top-1]]) add_edge(st[top-1],st[top],0), top--; add_edge(f,st[top--],0); if (f!=st[top]) st[++top]=f; st[++top]=h[i]; } while (top>1) add_edge(st[top-1],st[top],0), top--; dp_dfs(1); printf("%lld\n",dp[1]); } int main() { int n, m, u, v, w; bin_init(); n=Scan(); FO(i,1,n) u=Scan(), v=Scan(), w=Scan(), add_edge(u,v,w), add_edge(v,u,w); ans[1]=(LL)1<<60; dfs(1,0); m=Scan(); mem(head,0); while (m--) sol(); return 0; }