Luogu-2495 [SDOI2011]消耗战
虚树第一题
对于每次询问的点建立一棵虚树,然后在树上DP,一个点的答案就是这个点的父边切断的代价与所有儿子切断的代价去最小值,当然如果这个节点是资源点则必须切父边
注意在虚树上一条边的代价应该是中间所有边代价的最小值,在这道题里可以用到根节点边的最小值
建虚树的时候可以不去建那些在其他资源点下面的资源点,他们不会对答案造成影响,并且这样的话资源点都是叶节点,就不用给资源点打标记了23333
注意1点不能断,特判一下就好了。
#include<map>
#include<queue>
#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int maxn=3e5+100,maxm=5e5+100;
struct node{
int fa,dep,top,son,siz;
}tre[maxn];
int head[maxn],nex[maxm],v[maxm],w[maxm],num=1,fee[maxn];
int n,t,tim,m,p[maxn],dfn[maxn],top,st[maxn],a,b,c;
vector<int>e[maxn];
void dfs1(int x,int fa,int dep){
tre[x].dep=dep;
tre[x].fa=fa;
tre[x].siz=1;
dfn[x]=++tim;
for(int i=head[x];i;i=nex[i])
if(v[i]!=fa){
fee[v[i]]=min(fee[x],w[i]);
dfs1(v[i],x,dep+1);
if(tre[v[i]].siz>tre[tre[x].son].siz)
tre[x].son=v[i];
tre[x].siz+=tre[v[i]].siz;
}
}
void dfs2(int x,int fa,int top){
tre[x].top=top;
if(tre[x].son) dfs2(tre[x].son,x,top);
for(int i=head[x];i;i=nex[i])
if(v[i]!=fa&&v[i]!=tre[x].son)
dfs2(v[i],x,v[i]);
}
int Lca(int x,int y){
int fx=tre[x].top,fy=tre[y].top;
while(fx!=fy){
if(tre[fx].dep<tre[fy].dep) swap(x,y),swap(fx,fy);
x=tre[fx].fa;
fx=tre[x].top;
}
return tre[x].dep>tre[y].dep?y:x;
}
void add(int x,int y,int z){
v[++num]=y;
w[num]=z;
nex[num]=head[x];
head[x]=num;
v[++num]=x;
w[num]=z;
nex[num]=head[y];
head[y]=num;
}
bool cmp(int x,int y){
return dfn[x]<dfn[y];
}
void insert(int x){
if(top==1){
st[++top]=x;
return;
}
int lca=Lca(st[top],x);
if(lca==st[top]) return; //DP方便
while(top>1&&dfn[st[top-1]]>=dfn[lca])
e[st[top-1]].push_back(st[top]),top--;
if(lca!=st[top]) e[lca].push_back(st[top]),st[top]=lca;
st[++top]=x;
}
ll dp(int x){
if(e[x].size()==0) return 1ll*fee[x];
ll tot=0;
for(int i=0;i<e[x].size();i++)
tot+=dp(e[x][i]);
e[x].clear();
if(x==1) return tot;
else return min(tot,1ll*fee[x]);
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++)
scanf("%d%d%d",&a,&b,&c),add(a,b,c);
dfs1(1,1,1);
dfs2(1,1,1);
scanf("%d",&t);
while(t--){
scanf("%d",&m);
for(int i=1;i<=m;i++)
scanf("%d",&p[i]);
sort(p+1,p+m+1,cmp);
st[top=1]=1;
for(int i=1;i<=m;i++) insert(p[i]);
while(top>1) e[st[top-1]].push_back(st[top]),top--;
printf("%lld\n",dp(1));
}
return 0;
}