Codeforces 337D Book of evil

    一道树形dp,写出来是因为最近也做了道类似的.这题是看了分析的思路才做出来的,但感觉很多这样的dp都是利用类似的性质.像这题的话distDown很好想,但distUp的时候就很难想了,其实只要抓住distUp的必然经过父结点或者它的兄弟经过父结点,这周二的多校的那道也是类似的.但是要在线性时间里求出兄弟结点的时候就要注意,我们不可能遍历这个点的所有兄弟结点,所以好的办法就是存最大的两个,当该点是最大的,就用次大的算,其余的都用最大的算.多校的那个也是类似的,不过要存最大,次大,次次大,是有点麻烦.下面贴一记代码初始化为负无穷有点麻烦的样子- -0

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<vector>
#define maxn 100000
using namespace std;

int distUp[maxn+20];
int distDown[maxn+20];
int dp[maxn+20][2];
int vis[maxn+20];
vector<int> G[maxn+20];
int n,m,d;

void dfs1(int u)
{
    vis[u]=1;
    for(int i=0;i<G[u].size();i++)
    {
        int v=G[u][i];
        if(vis[v]) continue;
        vis[v]=1;
        dfs1(v);
        if(distDown[v]+1>dp[u][1]){
            dp[u][1]=distDown[v]+1;
            if(dp[u][1]>dp[u][0]){
                swap(dp[u][0],dp[u][1]);
            }
        }
    }
    distDown[u]=max(distDown[u],dp[u][0]);
}

void dfs2(int u)
{
    vis[u]=1;
    for(int i=0;i<G[u].size();i++)
    {
        int v=G[u][i];
        if(vis[v]) continue;
        if(distDown[v]+1==dp[u][0]){
            distUp[v]=max(max(dp[u][1]+1,distUp[v]),distUp[u]+1);
        }
        else{
            distUp[v]=max(max(dp[u][0]+1,distUp[v]),distUp[u]+1);
        }
        dfs2(v);
    }
}

void init(int n)
{
    for(int i=0;i<=n;i++){
        G[i].clear();
    }
    memset(mark,0,sizeof(mark));
    fill(distUp,distUp+n+1,-0x3fffffff);
    fill(distDown,distDown+n+1,-0x3fffffff);
    for(int i=0;i<=n;i++){
        dp[i][0]=-0x3fffffff;dp[i][1]=-0x3fffffff;
    }
}

int main()
{
    while(cin>>n>>m>>d)
    {
        init(n);
        int tmp;
        for(int i=0;i<m;i++){
            scanf("%d",&tmp);
            distDown[tmp]=0;
            distUp[tmp]=0;
        }
        int tu,tv;
        for(int i=0;i<n-1;i++){
            scanf("%d%d",&tu,&tv);
            G[tu].push_back(tv);
            G[tv].push_back(tu);
        }
        memset(vis,0,sizeof(vis));dfs1(1);
        memset(vis,0,sizeof(vis));dfs2(1);
        int ans=0;
        for(int i=1;i<=n;i++){
            if(distUp[i]<=d&&distDown[i]<=d){
                ans++;
            }
        }
        printf("%d\n",ans);
    }
    return 0;
}

 

posted @ 2013-08-18 14:18  chanme  阅读(343)  评论(0编辑  收藏  举报