【BZOJ2286】消耗战(SDOI2011)-虚树+树形DP
测试地址:消耗战
做法:本题需要用到虚树+树形DP。
这题如果只有一个询问,相信大家都会做了,比较裸的树形DP。但是询问次数很大,每次询问都DP的话,总的时间复杂度就是,无法承受。但是我们发现,总共涉及的询问点数不大,那么我们迫切需要一个关于而不是关于的算法。这时候就要拿出大杀器——虚树了。
虚树其实应该不算是一种数据结构,它是一类树上题的一种处理技巧。想法其实很简单,因为我们只询问个点,那么我们就只把这个点建在一棵树上就好了。但由于我们还要维护边上的信息,所以需要一些中间节点的支撑,这些中间节点就是询问点的LCA。我们把询问点按树上的DFS序排序,我们发现对于按DFS序排序的三个点,必定等于和中的一个。所以我们只需要求出DFS中相邻询问点的LCA即可。接下来就是建虚树了,当然我们不能直接在原树上DFS,不然时间复杂度又变了。我们按照DFS序将标记过的点放入栈中,我们需要时刻保证栈中的元素都在从根出发的一条链上。如果栈顶和当前要加入的点的LCA不等于栈顶,说明当前加入点不是栈顶的子孙,那么我们加一条从次栈顶(就是栈顶下面的一个元素)到栈顶的边,边权就是两点之间路径的最小值,可以倍增求出,然后一直下去直到当前点为某个栈中点的子孙,将当前点放入栈顶。由于每个元素仅入栈一次且出栈一次,所以复杂度是和相关的。
建出虚树后就可以在虚树上DP了,以上算法总的时间复杂度为,可以通过本题。
以下是本人代码:
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const ll inf=1000000000;
int n,m,k,first[250010]={0},tot=0,tim=0;
int fa[250010][21],dep[250010],b[1000010],order[250010];
int st[250010],top=0;
int firsti[250010]={0},toti=0;
ll f[250010],mn[250010][21];
bool res[250010]={0};
struct edge
{
int v,next;
ll w;
}e[500010],ei[500010];
void insert(int a,int b,ll w)
{
e[++tot].v=b,e[tot].w=w,e[tot].next=first[a],first[a]=tot;
}
void inserti(int a,int b,ll w)
{
ei[++toti].v=b,ei[toti].w=w,ei[toti].next=firsti[a],firsti[a]=toti;
}
void init(int v)
{
order[v]=++tim;
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa[v][0])
{
mn[e[i].v][0]=e[i].w;
fa[e[i].v][0]=v;
dep[e[i].v]=dep[v]+1;
init(e[i].v);
}
}
int lca(int x,int y)
{
if (dep[x]<dep[y]) swap(x,y);
for(int i=20;i>=0;i--)
if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
if (x==y) return x;
for(int i=20;i>=0;i--)
if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
ll findup(int x,int y)
{
ll ans=inf*inf;
for(int i=20;i>=0;i--)
if (dep[fa[y][i]]>=dep[x]) ans=min(ans,mn[y][i]),y=fa[y][i];
return ans;
}
bool cmp(int a,int b)
{
return order[a]<order[b];
}
void build()
{
sort(b+1,b+k+1,cmp);
toti=0;
for(int i=1;i<k;i++)
b[k+i]=lca(b[i],b[i+1]);
b[k<<1]=1;
sort(b+1,b+(k<<1)+1,cmp);
top=0;
for(int i=1;i<=(k<<1);i++)
if (i==1||b[top]!=b[i]) b[++top]=b[i];
k=top;
top=1;st[1]=1;
for(int i=2;i<=k;i++)
{
while (top>1&&lca(st[top],b[i])!=st[top])
{
inserti(st[top-1],st[top],findup(st[top-1],st[top]));
top--;
}
st[++top]=b[i];
}
while (top>1)
{
inserti(st[top-1],st[top],findup(st[top-1],st[top]));
top--;
}
}
void dp(int v)
{
f[v]=0;
for(int i=firsti[v];i;i=ei[i].next)
{
dp(ei[i].v);
if (res[ei[i].v]) f[v]+=ei[i].w;
else f[v]+=min(ei[i].w,f[ei[i].v]);
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int u,v;
ll w;
scanf("%d%d%lld",&u,&v,&w);
insert(u,v,w),insert(v,u,w);
}
fa[1][0]=fa[0][0]=0;
mn[1][0]=mn[0][0]=inf*inf;
dep[1]=1;dep[0]=0;
init(1);
for(int i=1;i<=20;i++)
for(int j=1;j<=n;j++)
{
fa[j][i]=fa[fa[j][i-1]][i-1];
mn[j][i]=min(mn[j][i-1],mn[fa[j][i-1]][i-1]);
}
scanf("%d",&m);
for(int i=1;i<=m;i++)
{
scanf("%d",&k);
for(int j=1;j<=k;j++)
{
scanf("%d",&b[j]);
res[b[j]]=1;
}
build();
dp(1);
printf("%lld\n",f[1]);
for(int j=1;j<=k;j++)
res[b[j]]=firsti[b[j]]=0;
}
return 0;
}