P2495 [SDOI2011]消耗战
虚树\(dp\)+倍增。
构建虚树:边权为原树这条链上的最小值。
状转方程:若子节点为关键点,则此边必断,\(dp[x]+=w[x->son]\)。
否则,可以选择断这条边或在子树中自行解决,\(dp[x]+=\min\{w[x->son],dp[son]\}\)。
链上最小值,可以选择树剖维护,这里使用倍增。
代码如下,仅供参考:
#include<bits/stdc++.h>
using namespace std;
const int maxn=3e5+10;
int n,m,cnt,a[maxn],tag[maxn];
inline int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
int beg[maxn],nex[maxn<<1],to[maxn<<1],w[maxn<<1],e;
inline void add(int x,int y,int z){
e++;nex[e]=beg[x];
beg[x]=e;to[e]=y;w[e]=z;
}
int dfn[maxn],dep[maxn],f[maxn][20],val[maxn][20];
inline void dfs(int x,int fa,int pri){
dfn[x]=++cnt;dep[x]=dep[fa]+1;
f[x][0]=fa;val[x][0]=pri;
for(int i=1;i<=19;i++){
f[x][i]=f[f[x][i-1]][i-1];
val[x][i]=min(val[x][i-1],val[f[x][i-1]][i-1]);
}
for(int i=beg[x];i;i=nex[i])
if(to[i]!=fa)dfs(to[i],x,w[i]);
}
inline int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=19;i>=0;i--)
if(dep[f[x][i]]>=dep[y])x=f[x][i];
if(x==y)return x;
for(int i=19;i>=0;i--)
if(f[x][i]!=f[y][i]){
x=f[x][i];
y=f[y][i];
}
return f[x][0];
}
inline int calc(int x,int y){
int res=1e9;swap(x,y);
for(int i=19;i>=0;i--)
if(dep[f[x][i]]>=dep[y]){
res=min(res,val[x][i]);
x=f[x][i];
}
return res;
}
long long dp[maxn];int siz[maxn];
inline void solve(int x){
dp[x]=0;siz[x]=tag[x];
for(int i=beg[x];i;i=nex[i]){
int t=to[i];solve(t);
siz[x]+=siz[t];
if(tag[t])dp[x]+=w[i];
else if(siz[t])dp[x]+=min(dp[t],1ll*w[i]);
}
if(!beg[x])dp[x]=1e9;
//printf("%d %d\n",x,dp[x]);
}
int st[maxn],top;
inline int cmp(int x,int y){return dfn[x]<dfn[y];}
int main(){
n=read();
int x,y,z;
for(int i=1;i<n;i++){
x=read(),y=read(),z=read();
add(x,y,z),add(y,x,z);
}
memset(val,0x3f,sizeof(val));
dfs(1,0,0);
m=read();
memset(beg,0,sizeof(beg)),e=0;
for(int T=1;T<=m;T++){
cnt=read();e=top=0;
for(int i=1;i<=cnt;i++)
a[i]=read();
for(int i=1;i<=cnt;i++)
tag[a[i]]=1;
sort(a+1,a+1+cnt,cmp);
st[++top]=1;beg[1]=0;
for(int i=1;i<=cnt;i++){
if(a[i]==1)continue;
int anc=lca(a[i],st[top]);
if(anc!=st[top]){
while(top>1&&dep[st[top-1]]>dep[anc])
add(st[top-1],st[top],calc(st[top-1],st[top])),top--;
if(anc!=st[top-1])beg[anc]=0,add(anc,st[top],calc(anc,st[top])),st[top]=anc;
else add(anc,st[top],calc(anc,st[top])),top--;
}
beg[a[i]]=0;st[++top]=a[i];
}
for(int i=1;i<top;i++)
add(st[i],st[i+1],calc(st[i],st[i+1]));
solve(1);printf("%lld\n",dp[1]);
for(int i=1;i<=cnt;i++)
tag[a[i]]=0;
}
return 0;
}
深深地感到自己的弱小。