【bzoj3611】 大工程

http://www.lydsy.com/JudgeOnline/problem.php?id=3611 (题目链接)

搞了1天= =,感觉人都变蠢了。。。

题意

  给出一个n个节点的树,每条边边权为1,给出q个询问,每次询问K个关键点,求出这k个点之间的两两距离和、最小距离和最大距离。

solution

  构造虚树,见 http://blog.csdn.net/MashiroSky/article/details/51971718

  之后在虚树上dp,有点麻烦。

  用size[u]表示在以u为根的子树上的关键点个数,b[u]表示虚树上的节点u是否是关键点,f[u]表示以u为根的子树上所有关键点到u的距离之和,mn[u]表示以u为根的子树上距离u最近的关键点到u的距离,mx[u]表示以u为根的子树上距离u最远的关键点到u的距离。

代码

// bzoj3611
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#include<set>
#define MOD 1000000007
#define inf 2147483640
#define LL long long
#define free(a) freopen(a".in","r",stdin);freopen(a".out","w",stdout);
using namespace std;
inline LL getint() {
    LL x=0,f=1;char ch=getchar();
    while (ch>'9' || ch<'0') {if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0' && ch<='9') {x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

const int maxn=1000010;
struct edge {int next,to,w;}e[maxn<<2];
int deep[maxn],fa[maxn][20],bin[20],dfn[maxn],head[maxn];
int mn[maxn],mx[maxn],s[maxn],a[maxn],b[maxn];
int n,K,cnt,top,q;
LL tot,ans1,ans2,f[maxn],size[maxn];

void insert(int u,int v) {
    if (u==v) return;
    e[++cnt].to=v;e[cnt].next=head[u];head[u]=cnt;
    e[cnt].w=deep[v]-deep[u];
}
void dfs(int u) {
    dfn[u]=++cnt;
    for (int i=1;i<20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
    for (int i=head[u];i;i=e[i].next) if (e[i].to!=fa[u][0]) {
            deep[e[i].to]=deep[u]+1;
            fa[e[i].to][0]=u;
            dfs(e[i].to);
        }
    head[u]=0;
}
int lca(int x,int y) {
    if (deep[x]<deep[y]) swap(x,y);
    int t=deep[x]-deep[y];
    for (int i=0;bin[i]<=t;i++) if (t&bin[i]) x=fa[x][i];
    for (int i=19;i>=0;i--)
        if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    return x==y?x:fa[x][0];
}
bool cmp(int x,int y) {
    return dfn[x]<dfn[y];
}
void dp(int u) {
    size[u]=b[u];
    f[u]=0;
    mn[u]=b[u]?0:inf;
    mx[u]=b[u]?0:-inf;
    for (int i=head[u];i;i=e[i].next) {
        int v=e[i].to;
        dp(v);
        tot+=(f[u]+size[u]*e[i].w)*size[v]+f[v]*size[u];
        size[u]+=size[v];
        f[u]+=f[v]+e[i].w*size[v];
        ans1=min(ans1,(LL)mn[u]+mn[v]+e[i].w);
        ans2=max(ans2,(LL)mx[u]+mx[v]+e[i].w);
        mn[u]=min(mn[u],mn[v]+e[i].w);
        mx[u]=max(mx[u],mx[v]+e[i].w);
    }
    head[u]=0;
}
void build() {
    scanf("%d",&K);
    for (int i=1;i<=K;i++) scanf("%d",&a[i]),b[a[i]]=1;
    sort(a+1,a+1+K,cmp);
    s[++top]=1;
    for (int i=1;i<=K;i++) {
        int t=a[i],f=0;
        while (top>0) {
            f=lca(a[i],s[top]);
            if (top>1 && deep[f]<deep[s[top-1]])
                insert(s[top-1],s[top]),top--;
            else if (deep[f]<deep[s[top]]) {insert(f,s[top--]);break;}
            else break;
        }
        if (s[top]!=f) s[++top]=f;
        s[++top]=t;
    }
    while (--top) insert(s[top],s[top+1]);
}
int main() {
    bin[0]=1;for (int i=1;i<20;i++) bin[i]=bin[i-1]<<1;
    scanf("%d",&n);
    for (int i=1;i<n;i++) {
        int u=getint(),v=getint();
        insert(u,v);insert(v,u);
    }
    cnt=0;
    dfs(1);
    scanf("%d",&q);
    while (q--) {
        build();
        ans1=inf;ans2=-inf;tot=0;
        dp(1);
        printf("%lld %lld %lld\n",tot,ans1,ans2);
        for (int i=1;i<=K;i++) b[a[i]]=0;
    }
    return 0;
}

  

posted @ 2016-09-27 22:46  MashiroSky  阅读(270)  评论(0编辑  收藏  举报