[SDOI2011]消耗战(虚树+树形动规)
虚树dp
虚树的主要思想:
- 不遍历没用的的节点以及没用的子树,从而使复杂度降低到\(\sum\limits k\)(k为询问的节点的总数)。
所以怎么办:
- 只把询问节点和其LCA放入询问的数组中。
1、建虚树
q.clear();
int m;
scanf("%d",&m);
for(int i=1;i<=m;++i){
int x;
scanf("%d",&x);
v[x]=1;
q.push_back(x);
}
sort(q.begin(),q.end(),cmp);
for(int i=0;i<m-1;++i){
q.push_back(LCA(q[i],q[i+1]));
}
q.push_back(1);
sort(q.begin(),q.end());
q.erase(unique(q.begin(),q.end()),q.end());
sort(q.begin(),q.end(),cmp);
q是一个vector,我们开始先对所有节点按欧拉序(即深度优先搜索是访问的顺序)排序,然后对每两个相邻的节点将LCA放入q中(可知这样一定会将所有有效节点放入q中)。然后一波去重,再按欧拉序排序即可。
如果你还不会LCA的话请到这里
2.遍历虚树
然后我们得到了一个遍历表,向深度优先搜索一样搜一遍即可。
这里要注意每个节点x只有当他下一个节点y是他的子节点时(即\(dfn[x]+size[x]>=dfn[y]\)时,其中\(dfn\)为欧拉序,\(size\)为子树大小)才访问下一个节点,并用下一个点的信息更新当前节点。
对于两点间的最短树边,我们可以用倍增来寻找(当然也可以用st表\(O(1)\)求,但这题并不要求)。
long long getmin(int x,int lca){
int ret=1e18;
for(int i=t-1;~i;--i){
if(dep[x]-(1<<i)>=dep[lca]){
ret=min(ret,c[x][i]);
x=fa[x][i];
}
}
return ret;
}
3.树型DP
对于每个节点,如果他一定要被割掉,则当前点的最小花费为当前点到父亲节点的最小树边,否则为所有子节点的最小花费和和当前点到父亲节点的最小树边的最小值。
void dfs1(){
int x=q[it];
long long ret=0;
while(1){
if(it+1==q.size())break;
if(dfn[q[it+1]]<=dfn[x]+sz[x]-1){
int y=q[++it];
if(v[y]==1){
dfs1();
dp[y]=getmin(y,x);
}
else dp[y]=1e18,dfs1(),dp[y]=min(dp[y],getmin(y,x));
ret+=dp[y];
}else break;
}
if(ret)dp[x]=min(dp[x],ret);
}
代码中的it为当前访问到的节点在q中的编号
然后就可以写出代码了,需要注意一些初始化的细节:
#include<bits/stdc++.h>
using namespace std;
const int N=600010,t=20;
int n;
int tot,bian[N<<1],nxt[N<<1],zhi[N<<1],head[N];
void add(int x,int y,int z){
tot++,bian[tot]=y,zhi[tot]=z,nxt[tot]=head[x],head[x]=tot;
}
int dfn[N],cnt;
int fa[N][t],c[N][t],dep[N],sz[N];
vector<int>q;
bool cmp(int x,int y){
return dfn[x]<dfn[y];
}
void dfs(int x,int f){
sz[x]=1;
dfn[x]=++cnt;
dep[x]=dep[f]+1;
fa[x][0]=f;
for(int i=1;i<t;++i){
fa[x][i]=fa[fa[x][i-1]][i-1];
c[x][i]=min(c[x][i-1],c[fa[x][i-1]][i-1]);
}
for(int i=head[x];i;i=nxt[i]){
int y=bian[i];
if(y==f)continue;
c[y][0]=zhi[i];
dfs(y,x);
sz[x]+=sz[y];
}
}
int LCA(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=t-1;~i;--i){
if(dep[x]-(1<<i)>=dep[y]){
x=fa[x][i];
}
}
if(x==y)return x;
for(int i=t-1;~i;--i){
if(fa[x][i]!=fa[y][i]){
x=fa[x][i],y=fa[y][i];
}
}
return fa[x][0];
}
long long getmin(int x,int lca){
int ret=1e18;
for(int i=t-1;~i;--i){
if(dep[x]-(1<<i)>=dep[lca]){
ret=min(ret,c[x][i]);
x=fa[x][i];
}
}
return ret;
}
#define IT vector<int>::iterator
long long dp[N];
int it,v[N];
void dfs1(){
int x=q[it];
long long ret=0;
while(1){
if(it+1==q.size())break;
if(dfn[q[it+1]]<=dfn[x]+sz[x]-1){
int y=q[++it];
if(v[y]==1){
dfs1();
dp[y]=getmin(y,x);
}
else dp[y]=1e18,dfs1(),dp[y]=min(dp[y],getmin(y,x));
ret+=dp[y];
}else break;
}
if(ret)dp[x]=min(dp[x],ret);
}
int main(){
dfn[0]=1e9;
cin>>n;
for(int i=1;i<n;++i){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
dfs(1,0);
int T;
cin>>T;
while(T--){
q.clear();
int m;
scanf("%d",&m);
for(int i=1;i<=m;++i){
int x;
scanf("%d",&x);
v[x]=1;
q.push_back(x);
}
sort(q.begin(),q.end(),cmp);
for(int i=0;i<m-1;++i){
q.push_back(LCA(q[i],q[i+1]));
}
q.push_back(1);
sort(q.begin(),q.end());
q.erase(unique(q.begin(),q.end()),q.end());
sort(q.begin(),q.end(),cmp);
m=q.size();it=0;
dp[q[0]]=1e18;
dfs1();
for(int i=0;i<m;++i){
v[q[i]]=0;
}
printf("%lld\n",dp[*q.begin()]);
}
}