[HEOI2014] 大工程

「题意」给你一棵树,每次询问若在在选中的k个点两两连接无相边,边权为原来树上的点对距离,求这些边的:1)权值和 2)最短的边 3)最长的边。所有k之和$\le$2*n。
「分析」虚树模板题。(但是独立写出来还是很振奋人心的合)直接考虑对虚树dp,设pmn[x]为x到x的子树内的关键点的最短距离,pmx[x]为最长距离,sum[x]为x到子树内所有关键点的距离之和。这些都很好处理。统计答案利用树形dp的常用技巧——有当前子树和以前的子树进行组合。

「实现」

/*
    写对啦!woc
*/

#include <bits/stdc++.h>
#define LL long long
using namespace std;
const int N=1e6+10;
const int inf=0x3f3f3f3f;

int n,q,k;
int dfn[N],dep[N],fa[N][20];
vector<int> e[N];

void pre(int x,int pa) {
    static int cnt=0;
    dfn[x]=++cnt;
    dep[x]=dep[fa[x][0]=pa]+1;
    for(int i=1; (1<<i)<=dep[x]; ++i) 
        fa[x][i]=fa[fa[x][i-1]][i-1];
    for(unsigned i=0; i<e[x].size(); ++i)
        if(e[x][i]!=pa) pre(e[x][i],x);
}
int lca(int x,int y) {
    if(dep[x]<dep[y]) swap(x,y);
    int dif=dep[x]-dep[y];
    for(int i=19; ~i; --i) 
        if(dif&(1<<i)) x=fa[x][i];
    if(x==y) return x;
    for(int i=19; ~i; --i) 
        if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}

int a[N],s[N],top,cnt;
int head[N],to[N<<1],len[N<<1],last[N<<1];
void add_edge(int x,int y,int w) {
    to[++cnt]=y;
    len[cnt]=w;
    last[cnt]=head[x];
    head[x]=cnt;
}
bool cmp(int x,int y) {
    return dfn[x]<dfn[y];
}
void ist(int x) {
    if(top==1) {
        if(x!=1) s[++top]=x; 
        return;
    }
    int t=lca(s[top],x);
    if(t!=s[top]) {
        for(; top>1&&dfn[s[top-1]]>=dfn[t]; top--) 
            add_edge(s[top-1],s[top],dep[s[top]]-dep[s[top-1]]);
        if(t!=s[top]) add_edge(t,s[top],dep[s[top]]-dep[t]), s[top]=t;
    }
    s[++top]=x;
}

bool mark[N];
LL lmn[N],lmx[N],siz[N],sum[N];
LL pum,pmn,pmx;

void dfs(int x) {
    lmn[x]=inf;
    lmx[x]=-inf;
    sum[x]=siz[x]=0;
    if(mark[x]) {
        lmx[x]=lmn[x]=0;
        siz[x]=1;
    }
    for(int i=head[x]; i; i=last[i]) {
        dfs(to[i]);
        pmn=min(pmn,lmn[x]+lmn[to[i]]+len[i]);
        pmx=max(pmx,lmx[x]+lmx[to[i]]+len[i]);
        pum+=sum[x]*siz[to[i]]
            +(siz[x]-mark[x])*sum[to[i]]
            +(siz[x]-mark[x])*len[i]*siz[to[i]];
        lmn[x]=min(lmn[x],lmn[to[i]]+len[i]);
        lmx[x]=max(lmx[x],lmx[to[i]]+len[i]);
        sum[x]+=sum[to[i]]+siz[to[i]]*len[i];
        siz[x]+=siz[to[i]];
    }
    if(mark[x]) pum+=sum[x];
    head[x]=0;
    mark[x]=0;
}

void print() {
printf("asphaush tree: \n");
    static queue<int> Q;
    Q.push(1);
    while(!Q.empty()) {
        int x=Q.front(); Q.pop();
        for(int i=head[x]; i; i=last[i]) {
            printf("%d -> %d, length is %d\n",x,to[i],len[i]);
            Q.push(to[i]);
        }
    }
}

void solve() {
//printf("\nnew solving case: \n");
    scanf("%d",&k);
    for(int i=1; i<=k; ++i) {
        scanf("%d",&a[i]);
        mark[a[i]]=true;
    }
    sort(a+1,a+k+1,cmp);
    cnt=0;
    s[top=1]=1;
    for(int i=1; i<=k; ++i) ist(a[i]);
    for(; top>1; top--)
        add_edge(s[top-1],s[top],dep[s[top]]-dep[s[top-1]]);
//    print();
    pum=0;
    pmn=inf;
    pmx=-inf;
    dfs(1);
    printf("%lld %lld %lld\n",pum,pmn,pmx);
}

int main() {
    scanf("%d",&n);
    for(int x,y,i=n; --i; ) {
        scanf("%d%d",&x,&y);
        e[x].push_back(y);
        e[y].push_back(x);
    }
    pre(1,0);
    scanf("%d",&q);
    while(q--) solve();
    return 0;
}

/*
10 
2 1 
3 2 
4 1 
5 2 
6 4 
7 5
8 6 
9 7 
10 9 

5 

2 
5 4 

2 
10 4 

2 
5 2 

2 
6 1 

2 
6 1
*/
posted @ 2019-01-11 18:23  nosta  阅读(265)  评论(0编辑  收藏  举报