SDOI2011 消耗战
首先,如果这道题只有一次询问的话,那么他就是一个树形DP,只要我们分个情况,用把它的子树中所有节点全部割断的价值和割断它的价值的最小值更新答案就可以了,如果这个点是关键点一定是要割断的。不过这样每次是\(O(n)\)的,会超时。
不过因为这道题总的关键点数非常少,如果我们每次能只保留关键信息,每次\(O(k)\)的进行计算,那么就是可行的。于是我们就有了虚树。
虚树……简单的来说,他不是“不存在的树”,应该叫树上的一个联通子树,只保留了我们需要的关键信息,而虚树上的点之间,在原树上的点和边都被省略掉了,取而代之的是本题要保留的关键信息。比如这道题的关键信息就是路径上边权的最小值。虚树上的节点是给定的关键点和给定的关键点的lca。
首先说怎么构建出虚树。我们先把所有给定的关键点按照dfs序进行排序(注意如果1号点不是给定的关键点,也要加进去)。
之后开始按dfs序由小到大枚举,在每个新的元素进入的时候,求一下它和栈顶元素的lca。如果栈顶元素的深度要比lca大,那么就说明,现在枚举到的点已经和栈顶元素不在一棵子树中,也就是栈顶元素所在的子树,虚树已经建立完成。可以直接退栈。退到栈中第二个元素的深度小于lca的时候,那我们就把栈顶元素向lca连边即可。
上述操作完成后,如果当前的lca就是栈顶元素的话,直接把当前点向lca连边,并且入栈。否则先把lca加入虚树(因为这个lca和当前虚树上结点有关,也会存储关键信息,要加入虚树),把它向栈顶元素连边,再将lca入栈,再执行上面的操作。这样枚举结束之后,虚树就建好了。(这个题不需要用邻接表存虚树,直接记录虚树上每个点的父亲就可以)
以上内容如果不大理解的话可以画一个图看一看。不用担心直接退栈会导致虚树没建全,因为在直接退栈的时候,那棵子树肯定已经被遍历过了,在遍历的时候,其必然已经向上面的节点连了边。
因为这道题的树形DP不难,所以建出虚树之后,题目难度就不大了。直接按dfs序在虚树上倒着枚举,然后更新答案即可。更新答案的方法和一开始说的是一样的。
本题需要先预处理两点间最短距离(倍增),然后顺便预处理来做倍增lca。
#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<queue>
#include<cstring>
#define rep(i,a,n) for(register int i = a;i <= n;i++)
#define per(i,n,a) for(register int i = n;i >= a;i--)
#define enter putchar('\n')
#define pr pair<int,int>
#define mp make_pair
#define fi first
#define sc second
using namespace std;
typedef long long ll;
const int M = 300005;
const int N = 10000005;
const ll INF = 2e9;
ll read()
{
ll ans = 0,op = 1;char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
while(ch >='0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
return ans * op;
}
struct edge
{
int next,to,v;
}e[M<<1];
int n,head[M],ecnt,vir[M],par[M],fa[20][M],dfn[M],idx,top,sta[M],x,y,z,T,dep[M];
ll dp[M],minv[20][M],virn,w[M];
bool vis[M];
bool cmp(int x,int y) {return dfn[x] < dfn[y];}
void add(int x,int y,int z)
{
e[++ecnt] = (edge){head[x],y,z};
head[x] = ecnt;
}
void dfs(int x,int f)
{
dfn[x] = ++idx;
for(int i = 1;(1 << i) <= dep[x];i++)
{
fa[i][x] = fa[i-1][fa[i-1][x]];
minv[i][x] = min(minv[i-1][x],minv[i-1][fa[i-1][x]]);
}
for(int i = head[x];i;i = e[i].next)
{
if(e[i].to == f) continue;
minv[0][e[i].to] = e[i].v,fa[0][e[i].to] = x,dep[e[i].to] = dep[x] + 1;
dfs(e[i].to,x);
}
}
inline int LCA(int x,int y)
{
if(dep[x] < dep[y]) swap(x,y);
per(i,19,0) if(dep[x] - (1<<i) >= dep[y]) x = fa[i][x];
if(x == y) return x;
per(i,19,0) if(fa[i][x] != fa[i][y]) x = fa[i][x],y = fa[i][y];
return fa[0][x];
}
inline void build()
{
if(!vis[1]) vir[++virn] = 1;
top = 0,sort(vir+1,vir+1+virn,cmp);
int cur = virn;
rep(i,1,cur)
{
int x = vir[i];
if(!top) {sta[++top] = x,par[x] = 0;continue;}
int lca = LCA(x,sta[top]);
while(dep[sta[top]] > dep[lca])
{
if(dep[sta[top-1]] < dep[lca]) par[sta[top]] = lca;
top--;
}
if(lca != sta[top]) vir[++virn] = lca,par[lca] = sta[top],sta[++top] = lca;
sta[++top] = x,par[x] = lca;
}
sort(vir+2,vir+1+virn,cmp);
}
inline ll dis(int x,int y)
{
ll cur = INF;
if(dep[x] < dep[y]) swap(x,y);
per(i,19,0) if(dep[x] - (1 << i) >= dep[y]) cur = min(cur,minv[i][x]),x = fa[i][x];
if(x == y) return cur;
per(i,19,0) if(fa[i][x] != fa[i][y]) cur = min(cur,min(minv[i][x],minv[i][y])),x = fa[i][x],y = fa[i][y];
return min(cur,min(minv[0][x],minv[0][y]));
}
inline ll solve()
{
rep(i,2,virn) w[vir[i]] = dis(vir[i],par[vir[i]]);
rep(i,1,virn) dp[vir[i]] = 0;
per(i,virn,2)
{
int x = vir[i];
if(vis[x]) dp[par[x]] += w[x];
else dp[par[x]] += min(dp[x],w[x]);
}
return dp[1];
}
int main()
{
n = read();
rep(i,1,n-1) x = read(),y = read(),z = read(),add(x,y,z),add(y,x,z);
dfs(1,0);
T = read();
while(T--)
{
virn = read();
rep(i,1,virn) vir[i] = read(),vis[vir[i]] = 1;
build(),printf("%lld\n",solve());
rep(i,1,virn) vis[vir[i]] = 0;
}
return 0;
}