[HEOI2014] 大工程
「题意」给你一棵树,每次询问若在在选中的k个点两两连接无相边,边权为原来树上的点对距离,求这些边的:1)权值和 2)最短的边 3)最长的边。所有k之和$\le$2*n。
「分析」虚树模板题。(但是独立写出来还是很振奋人心的合)直接考虑对虚树dp,设pmn[x]为x到x的子树内的关键点的最短距离,pmx[x]为最长距离,sum[x]为x到子树内所有关键点的距离之和。这些都很好处理。统计答案利用树形dp的常用技巧——有当前子树和以前的子树进行组合。
「实现」
/*
写对啦!woc
*/
#include <bits/stdc++.h>
#define LL long long
using namespace std;
const int N=1e6+10;
const int inf=0x3f3f3f3f;
int n,q,k;
int dfn[N],dep[N],fa[N][20];
vector<int> e[N];
void pre(int x,int pa) {
static int cnt=0;
dfn[x]=++cnt;
dep[x]=dep[fa[x][0]=pa]+1;
for(int i=1; (1<<i)<=dep[x]; ++i)
fa[x][i]=fa[fa[x][i-1]][i-1];
for(unsigned i=0; i<e[x].size(); ++i)
if(e[x][i]!=pa) pre(e[x][i],x);
}
int lca(int x,int y) {
if(dep[x]<dep[y]) swap(x,y);
int dif=dep[x]-dep[y];
for(int i=19; ~i; --i)
if(dif&(1<<i)) x=fa[x][i];
if(x==y) return x;
for(int i=19; ~i; --i)
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int a[N],s[N],top,cnt;
int head[N],to[N<<1],len[N<<1],last[N<<1];
void add_edge(int x,int y,int w) {
to[++cnt]=y;
len[cnt]=w;
last[cnt]=head[x];
head[x]=cnt;
}
bool cmp(int x,int y) {
return dfn[x]<dfn[y];
}
void ist(int x) {
if(top==1) {
if(x!=1) s[++top]=x;
return;
}
int t=lca(s[top],x);
if(t!=s[top]) {
for(; top>1&&dfn[s[top-1]]>=dfn[t]; top--)
add_edge(s[top-1],s[top],dep[s[top]]-dep[s[top-1]]);
if(t!=s[top]) add_edge(t,s[top],dep[s[top]]-dep[t]), s[top]=t;
}
s[++top]=x;
}
bool mark[N];
LL lmn[N],lmx[N],siz[N],sum[N];
LL pum,pmn,pmx;
void dfs(int x) {
lmn[x]=inf;
lmx[x]=-inf;
sum[x]=siz[x]=0;
if(mark[x]) {
lmx[x]=lmn[x]=0;
siz[x]=1;
}
for(int i=head[x]; i; i=last[i]) {
dfs(to[i]);
pmn=min(pmn,lmn[x]+lmn[to[i]]+len[i]);
pmx=max(pmx,lmx[x]+lmx[to[i]]+len[i]);
pum+=sum[x]*siz[to[i]]
+(siz[x]-mark[x])*sum[to[i]]
+(siz[x]-mark[x])*len[i]*siz[to[i]];
lmn[x]=min(lmn[x],lmn[to[i]]+len[i]);
lmx[x]=max(lmx[x],lmx[to[i]]+len[i]);
sum[x]+=sum[to[i]]+siz[to[i]]*len[i];
siz[x]+=siz[to[i]];
}
if(mark[x]) pum+=sum[x];
head[x]=0;
mark[x]=0;
}
void print() {
printf("asphaush tree: \n");
static queue<int> Q;
Q.push(1);
while(!Q.empty()) {
int x=Q.front(); Q.pop();
for(int i=head[x]; i; i=last[i]) {
printf("%d -> %d, length is %d\n",x,to[i],len[i]);
Q.push(to[i]);
}
}
}
void solve() {
//printf("\nnew solving case: \n");
scanf("%d",&k);
for(int i=1; i<=k; ++i) {
scanf("%d",&a[i]);
mark[a[i]]=true;
}
sort(a+1,a+k+1,cmp);
cnt=0;
s[top=1]=1;
for(int i=1; i<=k; ++i) ist(a[i]);
for(; top>1; top--)
add_edge(s[top-1],s[top],dep[s[top]]-dep[s[top-1]]);
// print();
pum=0;
pmn=inf;
pmx=-inf;
dfs(1);
printf("%lld %lld %lld\n",pum,pmn,pmx);
}
int main() {
scanf("%d",&n);
for(int x,y,i=n; --i; ) {
scanf("%d%d",&x,&y);
e[x].push_back(y);
e[y].push_back(x);
}
pre(1,0);
scanf("%d",&q);
while(q--) solve();
return 0;
}
/*
10
2 1
3 2
4 1
5 2
6 4
7 5
8 6
9 7
10 9
5
2
5 4
2
10 4
2
5 2
2
6 1
2
6 1
*/