bzoj 3829: [Poi2014]FarmCraft 树形dp+贪心

题意:

$mhy$ 住在一棵有 $n$ 个点的树的 $1$ 号结点上,每个结点上都有一个妹子。

$mhy$ 从自己家出发,去给每一个妹子都送一台电脑,每个妹子拿到电脑后就会开始安装 $zhx$ 牌杀毒软件,第 $i$ 个妹子安装时间为 $Ci$。

树上的每条边 $mhy$ 能且仅能走两次,每次耗费 $1$ 单位时间。$mhy$ 送完所有电脑后会回自己家里然后开始装 $zhx$ 牌杀毒软件。

卸货和装电脑是不需要时间的。

求所有妹子和 $mhy$ 都装好 $zhx$ 牌杀毒软件的最短时间。

 

题解:由于每条边最多走两次,所以如果进入点 $x$,必须要遍历完 $x$ 的所有子节点才能出来,我们考虑树形dp.

令 $f[i]$ 表示进入点 $i$ ,安装完 $i$ 子树中所有电脑的最小时刻,$size[i]$ 表示 $i$ 点子树中节点数量.

那么,对于点 $i$ 来说,我们就是要安排一个遍历 $i$ 点所有儿子的顺序,使得:

$max(f[1]+1,2size[1]+f[2]+1,2size[1]+2size[2]+f[3]+1,.....\sum_{i=1}^{n-1}size[i]+f[n]+1)$ 的最大值最小.

但是,我们并不知道该如何安排遍历儿子的顺序,但是我们可以考虑只有两个儿子的情况,然后发现:

若有 $i,j$ 而 $f[i]-2size[i]<f[j]-2size[j]$,则 $j$ 在 $i$ 之前访问更优.

对儿子排完序后依次累加即可.     

#include <bits/stdc++.h> 
#define N 500004   
#define LL long long 
#define setIO(s) freopen(s".in","r",stdin)      
using namespace std; 
char *p1,*p2,buf[100000];
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
int rd() {int x=0; char c=nc(); while(c<48) c=nc(); while(c>47) x=(((x<<2)+x)<<1)+(c^48),c=nc(); return x;}
struct data 
{  
    int f,size,id; 
    data(int f=0,int size=0,int id=0):f(f),size(size),id(id){}    
};         
bool cmp(data a,data b) 
{
    return a.f-2*a.size==b.f-2*b.size?a.f>b.f:a.f-2*a.size>b.f-2*b.size;       
}  
int n,edges; 
vector<data>G[N]; 
int hd[N],to[N<<1],nex[N<<1],val[N],f[N],size[N]; 
void add(int u,int v) 
{
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; 
}   
void dfs(int u,int ff) 
{    
    size[u]=1; 
    for(int i=hd[u];i;i=nex[i])    
    { 
        int v=to[i]; 
        if(v==ff) continue;     
        dfs(v,u);   
        G[u].push_back(data(f[v]+1,size[v],v)); 
        size[u]+=size[v];   
    }
    sort(G[u].begin(),G[u].end(),cmp);    
    int cur=0;          
    if(u!=1) f[u]=val[u];                           
    for(int i=0;i<G[u].size();++i)   
    {   
        f[u]=max(f[u],cur+G[u][i].f); 
        cur+=2*G[u][i].size;                         
    }        
}   
int main() 
{
    // setIO("input"); 
    int i,j; 
    n=rd(); 
    for(i=1;i<=n;++i)   val[i]=rd(); 
    for(i=1;i<n;++i) 
    {
        int u,v; 
        u=rd(),v=rd();   
        add(u,v),  add(v,u);   
    }
    dfs(1,0);     
    f[1]=max(f[1],  size[1]*2-2+val[1]);     
    printf("%d\n",f[1]);   
    return 0;    
}

  

posted @ 2019-11-02 15:35  EM-LGH  阅读(131)  评论(0编辑  收藏  举报