树形dp

1.codeforces 816 E. Karen and Supermarket

 题意:有n件商品,每件有价格ci,优惠券di,对于i>=2,使用di的条件为:xi的优惠券需要被使用,问初始金钱为b时 最多能买多少件商品? n<=5000,ci,di,b<=1e9

思路:

根据限制连边 转化为背包问题
dp[i][j][0/1] 表示以i为根的树中选了j件,第i件打不打折的最优值。
转移时枚举儿子选了多少 他就选了j减多少
最后统计答案 第一个小于等于限制钱数的就是。

#include<iostream>
#include<cstdio>
#include<cstring>

#define N 5007
#define inf 0x3f3f3f3f

using namespace std;
int n,b,ans,cnt,x;
int head[N],dp[N][N][2],siz[N],pr[N],pd[N];
struct edge
{
    int v,next;
}e[N<<1];

inline int read()
{
    int x=0,f=1;char c=getchar();
    while(c>'9'||c<'0'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
    return x*f;
}

inline void add(int u,int v)
{
    e[cnt].v=v;e[cnt].next=head[u];head[u]=cnt++;
}

void dfs(int u)
{
    siz[u]=1;
    dp[u][0][0]=0;
    dp[u][1][0]=pr[u];
    dp[u][1][1]=pr[u]-pd[u];
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].v; dfs(v);
        for(int j=siz[u];j>=0;j--)
        {
            for(int k=1;k<=siz[v];k++)
            {
                dp[u][j+k][0]=min(dp[u][j+k][0],dp[u][j][0]+dp[v][k][0]);
                dp[u][j+k][1]=min(dp[u][j+k][1],dp[u][j][1]+dp[v][k][0]);
                dp[u][j+k][1]=min(dp[u][j+k][1],dp[u][j][1]+dp[v][k][1]);
            }
        }siz[u]+=siz[v];
    }
}

int main()
{
    n=read();b=read();
    memset(dp,0x3f,sizeof dp);
    memset(head,-1,sizeof head);
    for(int i=1;i<=n;i++)
    {
        pr[i]=read();pd[i]=read();
        if(i!=1)
        {
            x=read();add(x,i);
        }
    }dfs(1);
    for(int i=n;i>=0;i--)
    {
        if(dp[1][i][1]<=b || dp[1][i][0]<=b)
        {
            ans=i;
            break;
        }
    }
    printf("%d\n",ans);
    return 0;
}
Code

 

2.洛谷P1272

题意:一棵树上要断掉大小为P的子树,求最少断掉边数

思路:

显然树形dp
dp[i][j]:i为根断掉子树大小为j最小边数
初始化dp[u][1]=1的度数
转移时枚举当前点断掉多少,算出连到的儿子断掉多少
因为由儿子转移过来,他们之间的连边不能断
但是转移时断掉了两次,所以答案减2

#include<iostream>
#include<cstdio>
#include<cstring>

#define N 151
#define inf 0x7f7f7f7f

using namespace std;
int dp[N][N],head[N],d[N];
int n,m,ans,cnt;
struct node
{
    int u,v,next; 
}e[N<<1];

inline int read()
{
    int x=0,f=1;char c=getchar();
    while(c>'9'||c<'0'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
    return x*f;
}

inline void add(int u,int v)
{
    e[++cnt].v=v;e[cnt].next=head[u];head[u]=cnt;
}

void dfs(int u,int fa)
{
    dp[u][1]=d[u];
    for(int i=head[u];i;i=e[i].next)
    {
        if(e[i].v!=fa)
        {
            dfs(e[i].v,u);
            for(int j=m;j>=1;j--)
              for(int k=1;k<=j;k++)
                dp[u][j]=min(dp[u][j],dp[e[i].v][k]+dp[u][j-k]-2);
        }
    }ans=min(ans,dp[u][m]);
}

int main()
{
    int x,y;
    memset(dp,1,sizeof dp);
    n=read();m=read();
    for(int i=1;i<n;i++)
    {
        x=read();y=read();
        add(x,y);add(y,x);
        d[x]++;d[y]++;
    }
    ans=inf;
    dfs(1,0);
    printf("%d\n",ans);
    return 0;
}
Code

 

3.洛谷P3174  https://www.luogu.org/problemnew/show/P3174

题解:https://www.cnblogs.com/L-Memory/p/9766728.html

 

#include<iostream>
#include<cstdio>
#include<cstring>

#define N 300007

using namespace std;
int f[N],head[N],son[N];
int n,m,k,ans,maxx;
struct edge
{
    int to,net;
}e[N<<1];

inline void add(int u,int v)
{
    e[++k].to=v;e[k].net=head[u];head[u]=k;
}

inline int read()
{
    int x=0,f=1;char c=getchar();
    while(c>'9'||c<'0'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
    return x*f;
}

void dfs(int u,int fa)
{
    int v,bigx=0,lowx=0;
    for(int i=head[u];i;i=e[i].net)
    {
        v=e[i].to;
        if(v!=fa)
        {
            dfs(v,u);
            if(f[v]>lowx)//维护最大链与次大链 
            {
                if(f[v]>bigx)lowx=bigx,bigx=f[v];
                else lowx=f[v];
            }
            f[u]=max(f[u],f[v]+son[u]-1);
        }
    }
    ans=max(ans,lowx+bigx+son[u]-1);//(-1是因为根节点重复加了) 
}

int main()
{
    n=read();m=read();
    for(int i=1;i<=m;i++)
    {
        int u,v;
        u=read();v=read();
        add(u,v);add(v,u);
        son[u]++;son[v]++; 
    }
    for(int i=1;i<=n;i++)f[i]=1;
    dfs(1,0);
    printf("%d",ans);
}
Code

 

4.bzoj4033  https://www.lydsy.com/JudgeOnline/problem.php?id=4033

题解:https://www.cnblogs.com/L-Memory/p/9768069.html

 

#include<iostream>
#include<cstdio>
#include<cstring>

#define N 2001
#define ll long long

using namespace std;
int n,k,ans,cnt,S,T;
int head[N],siz[N];
ll f[N][N],tmp;
struct edge{
    int u,v,net;
    ll w;
}e[N<<1];

inline int read()
{
    int x=0,f=1;char c=getchar();
    while(c>'9'||c<'0'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
    return x*f;
}

inline void add(int u,int v,ll w)
{
    e[++cnt].v=v;e[cnt].w=w;e[cnt].net=head[u];head[u]=cnt;
}

ll calc(ll val,int num,int x) 
{
    val=val*x*(k-x)+val*(num-x)*(n-k-(num-x));
    return val;
}

void dfs(int u)
{
    siz[u]=1;
    for(int i=head[u];i;i=e[i].net)
    {
        int v=e[i].v;
        if(siz[v]) continue;
        dfs(v);
        for(int x=siz[u];x>=0;x--) for(int y=siz[v];y>=0;y--)
        {
            tmp=f[u][x]+f[v][y]+calc(e[i].w,siz[v],y);
            f[u][x+y]=max(f[u][x+y],tmp);
        }
        siz[u]+=siz[v];
    }
}

int main()
{
    int x,y,z;
    n=read();k=read();
    for(int i=1;i<n;i++)
    {
        x=read();y=read();cin>>z;
        add(x,y,z);add(y,x,z);
    }
    dfs(1);printf("%lld\n",f[1][k]);
    return 0; 
}
Code

 

5.bzoj2525  https://www.lydsy.com/JudgeOnline/problem.php?id=2525

题解:https://www.cnblogs.com/L-Memory/p/9769404.html

#include<iostream>
#include<cstdio>

#define N 300007

using namespace std;
int n,m,h[N],cnt,d[N],sum,sm,mx[N],mn[N];
struct edge
{
    int ne,to;
}e[N<<1];

inline int read()
{
    int x=0,f=1;char c=getchar();
    while(c>'9'||c<'0'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
    return x*f;
}

void add(int u,int v)
{ 
    e[++cnt].to=v;e[cnt].ne=h[u];h[u]=cnt;
}

void dfs(int u,int fa,int w)
{
    mx[u]=-1e9,mn[u]=1e9;
    for(int i=h[u];i;i=e[i].ne)
        if(e[i].to!=fa)
        {
            dfs(e[i].to,u,w);
            mx[u]=max(mx[u],mx[e[i].to]+1);
            mn[u]=min(mn[u],mn[e[i].to]+1);
        }
    if(d[u]&&mn[u]>w) mx[u]=max(mx[u],0);
    if(mx[u]+mn[u]<=w)mx[u]=-1e9;
    if(mx[u]==w) sm++,mx[u]=-1e9,mn[u]=0;
}

bool ok(int w)
{
    sm=0;dfs(1,0,w);
    return sm+(mx[1]>=0)<=m;
}

int main()
{
    n=read(),m=read();
    for(int i=1;i<=n;i++)
        d[i]=read(),sum+=d[i];
    for(int i=1;i<n;i++)
    {
        int x=read(),y=read();
        add(x,y),add(y,x);
    }
    if(sum<=m)
    {
        puts("0");return 0;
    }
    int l=0,r=n,ans=n;
    while(l<=r)
    {
        int mid=(l+r)>>1;
        if(ok(mid)) r=mid-1,ans=mid;
        else l=mid+1;
    }
    printf("%d\n",ans);
    return 0;
}
Code

 

posted @ 2017-09-03 17:36  安月冷  阅读(191)  评论(0编辑  收藏  举报