[Bzoj2286]消耗战(虚树+DP)

Description

题目链接

Solution

在虚树上跑DP即可

Code

#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define ll long long
#define N 250010
using namespace std;

const ll Inf=1ll<<60;
struct info{int to,nex;ll w;}vir[N*2],e[N*2];
int n,m,tot,head[N],dfn[N],cnt;
int _log,f[N][20],dep[N];
int q[N],sta[N],top;
ll dis[N],dp[N];

inline int read(){
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

inline void Link(int u,int v,int w){
	e[++tot].to=v;e[tot].w=w;e[tot].nex=head[u];head[u]=tot;
}

inline void Link_vir(int u,int v){
	if(u==v) return;
	vir[++tot].to=v;vir[tot].nex=head[u];head[u]=tot;
}

void dfs(int u, int fa){
	dfn[u]=++cnt;
	for (int j=1;j<=_log;++j) f[u][j]=f[f[u][j-1]][j-1];

	for(int i=head[u];i;i=e[i].nex) {
		int v=e[i].to;
		if (v==fa) continue;
		f[v][0]=u;
		dis[v]=min(dis[u],e[i].w);
		dep[v]=dep[u]+1;
		dfs(v,u);
	}
}

int LCA(int u,int v){
    if(dep[u]>dep[v]) swap(u,v);
    int d=dep[v]-dep[u];
    
    for(int i=0;i<=_log;++i)
        if(d&(1<<i)) v=f[v][i];
    if(u==v) return v;
    
    for(int i=_log;i>=0;--i)
        if(f[u][i]!=f[v][i]){
            u=f[u][i];
            v=f[v][i];
        }
    return f[u][0];
}

void DP(int x){
    ll tmp=0;dp[x]=dis[x];
    for(int i=head[x];i;i=vir[i].nex){
    	int v=vir[i].to;
    	DP(v);
    	tmp+=dp[v];
    }
    head[x]=0;
    if(!tmp) dp[x]=dis[x];
    else if(tmp<dp[x]) dp[x]=tmp;
}

bool cmp(int a,int b){return dfn[a]<dfn[b];}
void solve(){
	m=read();tot=0;
	for(int i=1;i<=m;++i) q[i]=read();
	sort(q+1,q+m+1,cmp);
	cnt=1;
	for(int i=2;i<=m;++i) if(LCA(q[i],q[cnt])!=q[cnt]) q[++cnt]=q[i];
	sta[top=1]=1;
	for(int i=1;i<=cnt;++i){
		int grand=LCA(q[i],sta[top]);
		while(1){
        	if(dep[sta[top-1]]<=dep[grand]){
        		Link_vir(grand,sta[top]); top--;
        		if(sta[top]!=grand) sta[++top]=grand;
        		break;
        	}
        	Link_vir(sta[top-1],sta[top]); top--;
    	}
    	if(sta[top]!=q[i]) sta[++top]=q[i];
	}
	top--;
	while(top) Link_vir(sta[top],sta[top+1]),top--;
    DP(1);
    printf("%lld\n",dp[1]);
}

int main(){
	n=read();_log=log(n)/log(2);
	for(int i=1;i<n;++i){
		int u=read(),v=read(),w=read();
		Link(u,v,w);
		Link(v,u,w);
	}
	dis[1]=Inf;dfs(1,0);
	memset(head,0,sizeof(head));
	int k=read();while(k--) solve();
	return 0;
}
posted @ 2018-03-30 19:08  void_f  阅读(147)  评论(0编辑  收藏  举报