codeforces1029 E.Tree with Small Distances

题目链接

题意:给出一个 \(n\) 个结点的树,问如何选择结点进行连线使得结点\(1\)到其他所有结点的最短距离都小于等于 \(2\)

题解:这道题倒是自己想出来了。

首先有个结论:连线都是从结点1向其他结点连线,因为这样总是最优的。

由题意可知,当结点\(1\)向某个节点 \(u\) 连线后,与结点 \(u\) 直接相连的所有结点都能满足条件。

考虑树形dp。

\(d[u][0]\):结点u的子结点都满足条件,但是结点u不满足

\(d[u][1]\):结点u的子树(包括 \(u\) 自己)都满足条件,但结点u没有被连线

\(d[u][2]\):结点u的子树(包括 \(u\) 自己)都满足条件,且结点u与节点1连线

那么,对于叶子结点:

\(d[u][0]=0,d[u][1]=inf,d[u][2]=1\)

对于非叶子结点 \(u\) 和它的子结点\(v\):

\(d[u][0]=\sum d[v][1];\)

\(d[u][2]=\sum min(d[v][0],min(d[v][1],d[v][2]))+1\)

\(d[u][1]=\sum min(d[v][1],d[v][2]);\) //在这个式子中必须保证有至少一个子结点\(v\) 取的是 \(d[v][2]\)

代码:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<queue>
#include<stack>
using namespace std;
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
#define pb push_back
#define fi first
#define se second
#define dbg(...) cerr<<"["<<#__VA_ARGS__":"<<(__VA_ARGS__)<<"]"<<endl;
typedef vector<int> VI;
typedef long long ll;
typedef pair<int,int> PII;
const int inf=0x3fffffff;
const ll mod=1000000007;
const int maxn=2e5+10;
int head[maxn],dep[maxn];
int tol;
int d[maxn][3];
struct edge
{
    int to,next;
}e[maxn*2];

void add(int u,int v)
{
    e[++tol].to=v,e[tol].next=head[u],head[u]=tol;
    e[++tol].to=u,e[tol].next=head[v],head[v]=tol;
}
int cnt[maxn]; //节点度数
//d[u][0]-结点u的子结点都满足条件,但是结点u不满足
//d[u][1]-结点u的子树(包括u自己)都满足条件,但结点u没有被连线
//d[u][2]-结点u的子树(包括u自己)都满足条件,且结点u与节点1连线
void dfs(int u,int f)
{
    dep[u]=dep[f]+1;
    d[u][0]=0;
    d[u][1]=cnt[u]==1? 1e6:0;
    d[u][2]=1;
    int mi=1e9;
    for(int i=head[u];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==f) continue;
        dfs(v,u);
        d[u][0]+=d[v][1];
        d[u][1]+=min(d[v][1],d[v][2]);
        mi=min(mi,d[v][2]-d[v][1]);
        d[u][2]+=min(d[v][0],min(d[v][1],d[v][2]));
        rep(i,0,3) if(d[u][i]>1e6) d[u][i]=1e6;
    }
    if(mi>0) d[u][1]+=mi;
    rep(i,0,3) if(d[u][i]>1e6) d[u][i]=1e6;
}

int main()
{
    int n;
    scanf("%d",&n);
    rep(i,1,n)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        cnt[u]++,cnt[v]++;
        add(u,v);
    }
    dfs(1,0);
    int ans=0;
    rep(i,1,n+1) if(dep[i]==3) ans+=min(d[i][0],min(d[i][1],d[i][2]));
    printf("%d\n",ans);
    return 0;
}
posted @ 2018-08-30 20:39  tarjan's  阅读(112)  评论(0编辑  收藏  举报