[HEOI2014]大工程

题解:

首先建虚树是显然的

然后在上面dp

注意到有的点可能是真实不存在的

所以要认真的搞一下dp

首先分析一波

最长链是任意的

因为假如a->b b->c (a,c真实存在,b是假的)

那么a->c 一定大于b->a or c

所以这个跟求直径一样max_len1 max_len2就可以了

最短链要求是要在两个真实点之间的

所以令min_len[i]表示i的子树中真实存在的点到它的最短距离

(注意到叶子节点一定是)

求路径和就记录子树中到当前点的路径和,子树中真实点的数目就可以了

令max_len1表示

代码:

 

#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define N 2000100
struct re{
    int a,b;
    ll c;
}a[N],a2[N];
int head[N],bz[N][20],bz2[N][20],dep[N],dfn[N],cnt,l,l2,head2[N],n,m,ansmin,ansmax;
int max_len1[N],max_len2[N],min_len[N],st[N],b[N];
ll sum[N],sum2[N],num[N];
bool ff[N];
#define INF 1e9
void arr(ll x,ll y)
{
    a[++l].a=head[x];
    a[l].b=y;
    head[x]=l;
}
void dfs(ll x,ll father)
{
    dfn[x]=++cnt;
    dep[x]=dep[father]+1;
    bz[x][0]=father; bz2[x][0]=1;
    ll u=head[x];
    while (u)
    {
        ll v=a[u].b;
        if (v!=father) dfs(v,x);
        u=a[u].a;
    }
}
ll lca(ll x,ll y)
{
    if (dep[x]<dep[y]) swap(x,y);
    for (ll i=19;i>=0;i--)
        if (dep[bz[x][i]]>=dep[y]) x=bz[x][i];
    if (x==y) return(x);
    for (ll i=19;i>=0;i--)
        if (bz[x][i]!=bz[y][i])
        {
            x=bz[x][i]; y=bz[y][i];
        }
    return(bz[x][0]);
}
ll query(ll x,ll y)
{
    if (dep[x]<dep[y]) swap(x,y);
    ll ans=0;
    for (ll i=19;i>=0;i--)
        if (dep[bz[x][i]]>=dep[y]) ans+=bz2[x][i],x=bz[x][i];
    if (x==y) return(ans);
    for (ll i=19;i>=0;i--)
        if (bz[x][i]!=bz[y][i])
        {
            ans+=bz2[x][i]+bz2[y][i];
            x=bz[x][i]; y=bz[y][i];
        }
    return(ans+bz2[x][0]+bz2[y][0]);
}
bool cmp(ll x,ll y)
{
    return(dfn[x]<dfn[y]);
}
ll k;
queue<ll> q;
void arr2(ll x,ll y)
{
    q.push(x);
    a2[++l2].a=head2[x];
    a2[l2].b=y;
    if (x==n+1||y==n+1) a2[l2].c=0;
    else a2[l2].c=query(x,y);
    head2[x]=l2;
}
void js(ll x,ll fa)
{
    ll u=head2[x];
    if (ff[x]) num[x]=1,min_len[x]=0;
    while (u)
    {
        ll v=a2[u].b;
        if (v!=fa)
        {
            js(v,x);
            num[x]+=num[v];
            if (max_len1[v]+a2[u].c>=max_len1[x])
            {
                max_len2[x]=max_len1[x];
                max_len1[x]=max_len1[v]+a2[u].c;
            } else
            if (max_len1[v]+a2[u].c>max_len2[x])
                max_len2[x]=max_len1[v]+a2[u].c;
            if (num[v]>0)
            {
                ansmin=min(ansmin,min_len[x]+min_len[v]+int(a2[u].c));
                min_len[x]=min(min_len[x],min_len[v]+int(a2[u].c));
            }
        }
        u=a2[u].a;
    }
    ansmax=max(ansmax,max_len1[x]+max_len2[x]);
    u=head2[x];
    while (u)
    {
        ll v=a2[u].b;
        if (v!=fa)
        {
            sum[x]+=(num[x]-num[v])*num[v]*a2[u].c+sum2[v]*(num[x]-num[v])+sum[v];
            sum2[x]+=sum2[v]+a2[u].c*num[v];
        }
        u=a2[u].a;
    }
}
void solve()
{
    ll top=0; st[++top]=n+1;
    while (!q.empty())
    {
       ll x=q.front();
       head2[x]=0,ff[x]=0,sum[x]=0,num[x]=0,sum2[x]=0;
       max_len1[x]=0,max_len2[x]=0,min_len[x]=INF,q.pop();
    }
    l2=0; ansmax=0; ansmin=INF;
    for (ll i=1;i<=k;i++)
    {
        ff[b[i]]=1;
        ll tmp=lca(b[i],st[top]);
        while (true)
        {
          if (dfn[tmp]>=dfn[st[top-1]])
          {
            if (tmp!=st[top]) arr2(st[top],tmp),arr2(tmp,st[top]);
            top--; if (tmp!=st[top]) st[++top]=tmp;
            break;
          } else
          {
            arr2(st[top-1],st[top]); arr2(st[top],st[top-1]);
            top--;
          }
        }
        if (st[top]!=b[i]) st[++top]=b[i];
    }
    while (top>1)
    {
        arr2(st[top-1],st[top]); arr2(st[top],st[top-1]);
        top--;
    }
    js(n+1,0);
    cout<<sum[n+1]<<" "<<ansmin<<" "<<ansmax<<endl;
}
int main()
{
    cin>>n;
    arr(n+1,1);
    ll x,y;
    for (ll i=1;i<=n-1;i++)
       cin>>x>>y,arr(x,y),arr(y,x);
    dfs(n+1,0);
    for (ll i=1;i<=19;i++)
        for (ll j=1;j<=n+1;j++)
        {
            bz[j][i]=bz[bz[j][i-1]][i-1];
            bz2[j][i]=bz2[j][i-1]+bz2[bz[j][i-1]][i-1];
        }
    ll q;
    cin>>q;
    for (ll i=1;i<=n+10;i++) min_len[i]=INF;
    for (ll i=1;i<=q;i++)
    {
        cin>>k;
        ll len=0;
        for (ll j=1;j<=k;j++)
        {
            cin>>x; b[++len]=x;
        }
        sort(b+1,b+1+k,cmp);
        solve();
    }
    return 0;
}

 

posted @ 2018-04-17 00:27  尹吴潇  阅读(123)  评论(0编辑  收藏  举报