bzoj 3743 [ Coci 2015 ] Kamp —— 树形DP

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3743

一开始想到了树形DP,处理一下子树中的最小值,向上的最小值,以及子树中的最长路和向上的最长路,就可以得到答案,可以DP;

然而写着写着写不下去了,不会求向上最小值和最长路;

于是看看TJ,原来要再记录一个次长路!

然而写挫了,写不下去了...

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const maxn=5e5+5;
int n,m,hd[maxn],ct,fa[maxn],ans[maxn];
ll f[maxn][3],g[maxn][3],t[maxn],l[maxn][3];
bool vis[maxn];
struct N{
    int to,nxt,w;
    N(int t=0,int n=0,int w=0):to(t),nxt(n),w(w) {}
}ed[maxn<<1];
void add(int x,int y,int z){ed[++ct]=N(y,hd[x],z); hd[x]=ct;}
void dfs(int x,int ft)
{
    fa[x]=ft;
    for(int i=hd[x],u;i;i=ed[i].nxt)
    {
        if((u=ed[i].to)==ft)continue;
        dfs(u,x);
        vis[x]|=vis[u];
    }
}
void dp(int x)
{
    for(int i=hd[x],u;i;i=ed[i].nxt)
    {
        if((u=ed[i].to)==fa[x])continue;
        dp(u);
        f[x][0]+=(f[u][0]+2*vis[u]*ed[i].w);
        if(l[x][0]<l[u][0]+vis[u]*ed[i].w)
        {
            l[x][1]=l[x][0]; t[x]=u;
            l[x][0]=l[u][0]+vis[u]*ed[i].w;
        }
        else if(l[x][1]<l[u][0]+vis[u]*ed[i].w)l[x][1]=l[u][0]+vis[u]*ed[i].w;
    }
    f[x][1]=f[x][0]-l[x][0];
}
void dp2(int x)
{
    for(int i=hd[x],u;i;i=ed[i].nxt)
    {
        if((u=ed[i].to)==fa[x])continue;
        ans[x]=min(f[x][0]+g[x][1],f[x][1]+g[x][0]);
        g[u][0]=g[x][0]+f[x][0]-f[u][0];
        if(u==t[x])g[u][1]=min(g[x][1]+f[x][0]-f[u][0],g[u][0]-l[x][1])-ed[i].w;
        else g[u][1]=min(g[x][1]+f[x][0]-f[u][0],g[u][0]-l[x][0]);
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1,x,y,z;i<n;i++)
    {
        scanf("%d%d%d",&x,&y,&z);
        add(x,y,z); add(y,x,z);
    }
    for(int i=1,x;i<=m;i++)
        scanf("%d",&x),vis[x]=1;
    dfs(1,0); dp(1); dp2(1);
    for(int i=1;i<=n;i++)printf("%lld\n",ans[i]);
    return 0;
}

然后滚去看TJ,再次学习了一下优美的树形DP写法...

dis[x] 表示总路长,f[x][0] 表示子树中最长路,f[x][1] 表示子树中次长路,f[x][2] 表示向上最长路;

然后优美 dfs 即可...

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const maxn=5e5+5,inf=0x3f3f3f3f;
int n,m,hd[maxn],ct,siz[maxn];
ll dis[maxn],f[maxn][4];
bool vis[maxn];
struct N{
    int to,nxt,w;
    N(int t=0,int n=0,int w=0):to(t),nxt(n),w(w) {}
}ed[maxn<<1];
void add(int x,int y,int z){ed[++ct]=N(y,hd[x],z); hd[x]=ct;}
void dfs(int x,int fa)
{
//    dis[x]=(vis[x]?0:-inf);
//    siz[x]=(vis[x]?1:0);
    siz[x]=vis[x]; dis[x]=0;
    for(int i=hd[x],u;i;i=ed[i].nxt)
    {
        if((u=ed[i].to)==fa)continue;
        dfs(u,x); siz[x]+=siz[u];
//        dis[x]+=dis[u]+ed[i].w; 
        dis[x]+=dis[u]+(siz[u]?ed[i].w:0);
    }
}
void dfs2(int x,int fa)
{
    for(int i=hd[x],u;i;i=ed[i].nxt)
    {
        if((u=ed[i].to)==fa)continue;
        dis[u]=dis[x];//
        if(siz[u])dis[u]-=ed[i].w;
        if(siz[u]<m)dis[u]+=ed[i].w;//siz[u]==0
        dfs2(u,x);
    }
}
void dfs3(int x,int fa)
{
    f[x][0]=f[x][1]=(vis[x]?0:(ll)-inf*inf);
    for(int i=hd[x],u;i;i=ed[i].nxt)
    {
        if((u=ed[i].to)==fa)continue;
        dfs3(u,x);
        if(f[u][0]+ed[i].w>f[x][0])
        {
            f[x][1]=f[x][0];
            f[x][0]=f[u][0]+ed[i].w;
        }
        else if(f[u][0]+ed[i].w>f[x][1])f[x][1]=f[u][0]+ed[i].w;
    }
}
void dfs4(int x,int fa)
{
    for(int i=hd[x],u;i;i=ed[i].nxt)
    {
        if((u=ed[i].to)==fa)continue;
        if(f[u][0]+ed[i].w==f[x][0])f[u][2]=f[x][1]+ed[i].w;
        else f[u][2]=f[x][0]+ed[i].w;
        f[u][2]=max(f[u][2],f[x][2]+ed[i].w);
        dfs4(u,x);
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1,x,y,z;i<n;i++)
    {
        scanf("%d%d%d",&x,&y,&z);
        add(x,y,z); add(y,x,z);
    }
    for(int i=1,x;i<=m;i++)scanf("%d",&x),vis[x]=1;
    dfs(1,0); dfs2(1,0); dfs3(1,0);
    f[1][2]=(vis[1]?0:(ll)-inf*inf); dfs4(1,0);
    for(int i=1;i<=n;i++)printf("%lld\n",2*dis[i]-max(f[i][0],f[i][2]));
    return 0;
}

 

posted @ 2018-07-25 15:28  Zinn  阅读(331)  评论(0编辑  收藏  举报