BZOJ 2286: [Sdoi2011]消耗战(虚树+树形dp)
解题思路
看到\(\sum\limits k_i<=500000\),应该是虚树了。然后发现把虚树建好后就是个\(sb\)树形\(dp\)了。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
const int N=250010;
typedef long long LL;
const LL inf=1e18;
template<class T> void rd(T &x){
x=0;char ch=getchar();
while(!isdigit(ch)) ch=getchar();
while(isdigit(ch)) x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
}
int to[N<<1],nxt[N<<1],val[N<<1],a[N<<1],tot,dfn[N],num,tag[N];
int n,m,head[N],siz[N],top[N],son[N],fa[N];
int stk[N<<1],tp,cnt,to_[N],nxt_[N],head_[N],cnt_;
int g[N][25],zz[N][25],val_[N];
LL dep[N],f[N][3],ans;
inline void add(int bg,int ed,int z){
to[++cnt]=ed,nxt[cnt]=head[bg],val[cnt]=z,head[bg]=cnt;
}
inline void add_(int bg,int ed,int z){
// cout<<bg<<" "<<ed<<" "<<z<<endl;
to_[++cnt_]=ed,nxt_[cnt_]=head_[bg],head_[bg]=cnt_,val_[cnt_]=z;
}
void dfs1(int x,int F,LL d){
dep[x]=d;siz[x]=1;fa[x]=F;dfn[x]=++num;
int maxson=-1,u;zz[x][0]=F;
for(int i=1;i<=20;i++){
zz[x][i]=zz[zz[x][i-1]][i-1];
g[x][i]=min(g[x][i-1],g[zz[x][i-1]][i-1]);
}
for(int i=head[x];i;i=nxt[i]){
u=to[i];if(u==F) continue;g[u][0]=val[i];
dfs1(u,x,d+val[i]);siz[x]+=siz[u];
if(siz[u]>maxson) maxson=siz[u],son[x]=u;
}
}
void dfs2(int x,int topf){
top[x]=topf;if(!son[x]) return ;
dfs2(son[x],topf);
for(int i=head[x];i;i=nxt[i]){
int u=to[i];if(u==fa[x] || u==son[x]) continue;
dfs2(u,u);
}
}
inline bool cmp(int x,int y){
return dfn[x]<dfn[y];
}
inline int LCA(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]>=dep[top[y]]) x=fa[top[x]];
else y=fa[top[y]];
}
return dep[x]>dep[y]?y:x;
}
inline int dis(int x,int y){
int ret=1e9+1;
for(int i=22;i>=0;i--){
if(dep[zz[x][i]]<dep[y]) continue;
if(!g[x][i]) continue;
ret=min(ret,g[x][i]);
x=zz[x][i];
}
return ret;
}
void dp(int x){
f[x][0]=f[x][1]=0;
for(int i=head_[x];i;i=nxt_[i]){
int u=to_[i];dp(u);
f[x][1]+=min(f[u][0],f[u][1]);
f[x][0]+=min(f[u][0],f[u][1]+val_[i]);
}
if(tag[x]) f[x][0]=inf;tag[x]=0;
}
inline void work(){
a[++tot]=1;sort(a+1,a+1+tot,cmp);int u=tot;
for(int i=2;i<=tot;i++) a[++u]=LCA(a[i],a[i-1]);
sort(a+1,a+1+u);tot=unique(a+1,a+1+u)-a-1;
sort(a+1,a+1+tot,cmp);tp=0;
stk[++tp]=a[1];cnt_=0;
for(int i=2;i<=tot;i++){
while(tp && LCA(stk[tp],a[i])!=stk[tp]) tp--;
add_(stk[tp],a[i],dis(a[i],stk[tp]));
stk[++tp]=a[i];
}
/*
for(int i=1;i<=tot;i++){
cout<<a[i]<<" :"<<endl;
for(int j=head_[a[i]];j;j=nxt_[j]){
cout<<to_[j]<<" ";
}
cout<<endl;
}
*/
dp(1);ans=f[1][0];
for(int i=1;i<=tot;i++) head_[a[i]]=0;
}
int main(){
rd(n);int x,y,z;
for(int i=1;i<n;i++){
rd(x);rd(y);rd(z);
add(x,y,z),add(y,x,z);
}
dfs1(1,0,0);dfs2(1,1);
rd(m);
while(m--){
rd(tot);
for(int i=1;i<=tot;i++) rd(a[i]),tag[a[i]]=1;
work();
printf("%lld\n",ans);
}
return 0;
}