【NOIP模拟】怪兽

题面

大 M 是一只怪兽,准备到比特王国吃人。比特王国有 n 个城市,城市之间由 n-1 条无向的路径连接,通过每条路径的时间为 1。其中有 m 个特别的城市,这 m 个 城市里都各有一个大神,于是大 M 打算不管普通人,只吃掉这些大神。然而大 M 是 一只具有特别能力的怪物,它可以一开始降临到 n 个城市中的任意一个城市,同时还 有一次机会在任意两个城市间打开一个虫洞,不消耗时间就能相互到达。 大 M 想知道它最少要花多少时间来吃掉这些大神(吃的时间忽略不计),如 果你不帮它它就会吃了你。 当然,大 M 是不属于这个时空的存在,所以在他吃完所有大神之后需要回到 最初降临的城市,通过时空门返回原来的位面。虫洞只能走一次。

第 1 行输出一个整数表示为了时间最短,大 M 一开始应该降临在哪个城市 (如果有多个最短时间则输出序号最小的城市编号)。 第 2 行输出一个整数表示能达到的最短时间。1<=m<=n<=123456。

分析

这题其实挺妙,按一般树形dp的思维能做,但是这一种思路更妙。

首先很显然,所有标记了的大神点的父亲必须走,把这些点标记出来,再把一条直路上的父亲结点略去,之间将最远的祖先与大神点的边权设为中间经过的边数。

如下图,在这样一棵生成的树上的所有点都必须走,且从任意一点出发都是等效的,这解决了我们第一个问题,直接生成树上找一个序号最小的点。

而代价呢?其实就是边权之和*2-树的直径,因为有了要回到起点的条件,所以边都要走两遍,而虫洞自然建在距离最远的两点,即直径上两点。

代码

#include<bits/stdc++.h>
using namespace std;
#define N 200000
int n,m,k,p,st,cnt,cot,dis,ans;
int d[N],mark[N],siz[N],first[N],head[N];
struct email
{
    int u,v,w;
    int nxt;
}e[N*2],g[N*2];
inline void add(int u,int v)
{
    e[++cnt].nxt=first[u];first[u]=cnt;
    e[cnt].u=u;e[cnt].v=v;
}
inline void readd(int u,int v,int w)
{
    g[++cot].nxt=head[u];head[u]=cot;
    g[cot].u=u;g[cot].v=v;g[cot].w=w;
}

void dfs(int u,int fa)
{
    siz[u]=mark[u];
    for(int i=first[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if(v==fa)continue;
        dfs(v,u);
        siz[u]+=siz[v];
    }
}

void dfs1(int u,int fa,int top,int w)
{
    int dalao=0;
    for(int i=first[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if(v==fa)continue;
        if(siz[v])dalao++;
    }
    if(dalao>1||mark[u])
    {
        if(top)readd(u,top,w),readd(top,u,w);
        top=u;w=0;
    }
    for(int i=first[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if(v==fa||siz[v]==0)continue;
        dfs1(v,u,top,w+1);
    }
}
inline void dfs2(int u,int fa)
{
    for(int i=head[u];i;i=g[i].nxt)
    {
        int v=g[i].v,w=g[i].w;
        if(v==fa)continue;
        d[v]=d[u]+w;
        dfs2(v,u);
    }
}

int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        add(u,v);add(v,u);
    }
    for(int i=1;i<=m;i++)scanf("%d",&k),mark[k]=1;
    st=n;dfs(1,0);dfs1(1,0,0,0);
    dfs2(k,0);

    for(int i=1;i<=n;i++)
        if(d[i]>dis)dis=d[i],p=i;
    memset(d,0,sizeof(d));dfs2(p,0);
    for(int i=1;i<=n;i++)
        dis=max(dis,d[i]);
    for(int i=1;i<=cot;i++)st=min(st,min(g[i].v,g[i].u)),ans+=g[i].w;
    printf("%d\n%d\n",st,ans-dis);
    return 0;
}

 

posted @ 2018-10-19 07:49  WJEMail  阅读(330)  评论(0编辑  收藏  举报