bzoj2286: [Sdoi2011]消耗战 虚树
在一场战争中,战场由n个岛屿和n-1个桥梁组成,保证每两个岛屿间有且仅有一条路径可达。现在,我军已经侦查到敌军的总部在编号为1的岛屿,而且他们已经没有足够多的能源维系战斗,我军胜利在望。已知在其他k个岛屿上有丰富能源,为了防止敌军获取能源,我军的任务是炸毁一些桥梁,使得敌军不能到达任何能源丰富的岛屿。由于不同桥梁的材质和结构不同,所以炸毁不同的桥梁有不同的代价,我军希望在满足目标的同时使得总代价最小。
侦查部门还发现,敌军有一台神秘机器。即使我军切断所有能源之后,他们也可以用那台机器。机器产生的效果不仅仅会修复所有我军炸毁的桥梁,而且会重新随机资源分布(但可以保证的是,资源不会分布到1号岛屿上)。不过侦查部门还发现了这台机器只能够使用m次,所以我们只需要把每次任务完成即可。
Input
第一行一个整数n,代表岛屿数量。
接下来n-1行,每行三个整数u,v,w,代表u号岛屿和v号岛屿由一条代价为c的桥梁直接相连,保证1<=u,v<=n且1<=c<=100000。
第n+1行,一个整数m,代表敌方机器能使用的次数。
接下来m行,每行一个整数ki,代表第i次后,有ki个岛屿资源丰富,接下来k个整数h1,h2,…hk,表示资源丰富岛屿的编号。
Output
输出有m行,分别代表每次任务的最小代价。
题意:简单来说就是每次给k个点让1和给定点不能联通,求最小花费
解法:虚树,虚树就是把所有需要操作的点和他们的lca抠出来,然后直接在上面dp,保证每次抠出来的点不超过2k,
建树:先把需操作的点按dfs序排序,然后维护一个栈,表示从根到栈顶的链包含的需操作的点,
考虑栈顶元素是p,栈第二个元素是q,需插入的点是x,
如果lca(p,x)为p,代表x在p子树中,直接入栈即可,又lca(p,x)不可能为p(因为按dfs序插入的)
1.如果lca(p,x)的深度比q的深度小,那么链接栈顶和栈次顶,pop栈顶
2.如果lca(p,x)的深度比q的深度大或相同,那么链接lca和p,然后pop栈顶,lca入栈,x入栈(当lca(p,x)和q深度相同,那么不用加入lca,q就是lca)
最后把栈中元素全部出栈,并链接
最后在虚树上dp,如果当前点要被ban,那么肯定删上面的一条边,否则选上面的边和下面的dp之和中小的
(虚树题目很明显,多次查询,每次选取一些点操作,总的点不会很大,那么就可以用虚树)
/**************************************************************
Problem: 2286
User: walfy
Language: C++
Result: Accepted
Time:12852 ms
Memory:61228 kb
****************************************************************/
//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast,no-stack-protector")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
//#pragma GCC optimize("unroll-loops")
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define pi acos(-1.0)
#define ll long long
#define vi vector<int>
#define mod 1000000007
#define ld long double
#define C 0.5772156649
#define ls l,m,rt<<1
#define rs m+1,r,rt<<1|1
#define pil pair<int,ll>
#define pli pair<ll,int>
#define pii pair<int,int>
#define cd complex<double>
#define ull unsigned long long
#define base 1000000000000000000
#define fio ios::sync_with_stdio(false);cin.tie(0)
using namespace std;
const double eps=1e-6;
const int N=250000+10,maxn=500000+10,inf=0x3f3f3f3f,INF=0x3f3f3f3f3f3f3f3f;
vector<pair<int,int> >v[N];
vi in;
struct edge{
int to,Next,c;
}e[maxn];
int cnt,head[N],deep[N];
int fa[20][N],mi[20][N],l[N],res;
void init()
{
cnt=res=0;
memset(head,-1,sizeof head);
memset(mi,inf,sizeof mi);
}
void add(int u,int v,int c)
{
e[cnt].to=v;
e[cnt].c=c;
e[cnt].Next=head[u];
head[u]=cnt++;
}
void dfs(int u,int f,int dep)
{
l[u]=++res;
deep[u]=dep;
for(int i=head[u];~i;i=e[i].Next)
{
int x=e[i].to;
if(x!=f)
{
fa[0][x]=u;mi[0][x]=e[i].c;
dfs(x,u,dep+1);
}
}
}
void gao(int n)
{
for(int i=1;i<20;i++)
{
for(int j=1;j<=n;j++)
{
fa[i][j]=fa[i-1][fa[i-1][j]];
mi[i][j]=min(mi[i-1][fa[i-1][j]],mi[i-1][j]);
}
}
}
int lca(int x,int y)
{
if(deep[x]>deep[y])swap(x,y);
for(int i=19;i>=0;i--)
if(((deep[y]-deep[x])>>i)&1)
y=fa[i][y];
if(x==y)return x;
for(int i=19;i>=0;i--)
{
if(fa[i][x]!=fa[i][y])
{
x=fa[i][x];
y=fa[i][y];
}
}
return fa[0][x];
}
int getmi(int a,int b)
{
if(deep[a]>deep[b])swap(a,b);
int ans=inf;
for(int i=19;i>=0;i--)
{
if(deep[fa[i][b]]>deep[a])
{
ans=min(ans,mi[i][b]);
b=fa[i][b];
}
}
return min(ans,mi[0][b]);
}
void add1(int a,int b,int c){v[a].pb(mp(b,c));in.pb(a);in.pb(b);}
int st[N],top,a[N];
ll dp[N];
void ins(int x)
{
if(!top){st[++top]=x;return ;}
int lc=lca(st[top],x);
while(top>1&&deep[st[top-1]]>deep[lc])
add1(st[top-1],st[top],getmi(st[top-1],st[top])),top--;
if(top>=1&&deep[st[top]]>deep[lc])
add1(lc,st[top],getmi(st[top],lc)),top--;
if(!top||deep[st[top]]<deep[lc])st[++top]=lc;
st[++top]=x;
}
bool cmp(int a,int b){return l[a]<l[b];}
bool ban[N];
void dfs1(int u,ll mm)
{
for(int i=0;i<v[u].size();i++)
dfs1(v[u][i].fi,v[u][i].se);
if(ban[u]){dp[u]=mm;return ;}
else
{
ll res=0;
for(int i=0;i<v[u].size();i++)
res+=dp[v[u][i].fi];
dp[u]=min(res,mm);
}
}
int main()
{
int n;
scanf("%d",&n);
init();
for(int i=1;i<n;i++)
{
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
add(a,b,c);add(b,a,c);
}
dfs(1,-1,1);gao(n);
int m;scanf("%d",&m);
for(int i=1;i<=m;i++)
{
in.clear();
int k;
scanf("%d",&k);
for(int j=0;j<k;j++)scanf("%d",&a[j]),ban[a[j]]=1;
sort(a,a+k,cmp);
top=0;ins(1);
for(int j=0;j<k;j++)ins(a[j]);
while(top>=2)add1(st[top-1],st[top],getmi(st[top-1],st[top])),top--;
dfs1(1,1e18);
printf("%lld\n",dp[1]);
for(int j=0;j<in.size();j++)
v[in[j]].clear(),dp[in[j]]=0,ban[in[j]]=0;
}
return 0;
}
/********************
10
1 5 13
1 9 6
2 1 19
2 4 8
2 3 91
5 6 8
7 5 4
7 8 31
10 7 9
1000
2 10 6
4 5 7 8 3
3 9 4 6
********************/