树型DP

1.二叉苹果树
树上背包

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define YES {puts("YES");return;}
#define NO {puts("NO");return ;}
using namespace std;
const int maxn=2e5+101;
const int MOD=1e9+7;
const ll inf=2147483647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}

int n,q;
int tot,head[maxn],to[maxn],nx[maxn],w[maxn];
void add(int x,int y,int z){
    to[++tot]=y;nx[tot]=head[x];head[x]=tot;w[tot]=z;
}
int dp[101][101];
//dp[i][j]表示i的子树,保留了j个树枝的最大值
void dfs(int x,int fa){
    for(int i=head[x];i;i=nx[i]){
        int v=to[i];if(v==fa)continue;
        dfs(v,x);
        for(int j=q;j>=0;j--){
            for(int k=0;k<j;k++){
                dp[x][j]=max(dp[x][j],dp[x][k]+dp[v][j-k-1]+w[i]);
            }
        }
    }
    return ;
}
int main(){
    n=read();q=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read(),z=read();
        add(x,y,z);add(y,x,z);
    }
    dfs(1,1);
    cout<<dp[1][q];
    return 0;
}

选课

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define YES {puts("YES");return;}
#define NO {puts("NO");return ;}
using namespace std;
const int maxn=2e6+101;
const int MOD=1e9+7;
const ll inf=2147483647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
ll power(ll x,ll y){
    ll ans=1;
    while(y){
        if(y&1)ans=ans*x%MOD;
        y>>=1;x=x*x%MOD;
    }
    return ans;
}

int n,m,a[maxn];
int tot,head[maxn],nx[maxn],to[maxn];
void add(int x,int y){
    to[++tot]=y;nx[tot]=head[x];head[x]=tot;
}
ll dp[301][301];
void dfs(int x,int fa){
    dp[x][1]=a[x];
    for(int i=head[x];i;i=nx[i]){
        int v=to[i];if(v==fa)continue;
        dfs(v,x);
        for(int j=m+1;j>=0;j--){
            if(!dp[x][j])continue;
            for(int k=0;k<=m+1;k++){
                if(!dp[v][k])continue;
                dp[x][j+k]=max(dp[x][j+k],dp[x][j]+dp[v][k]);
            }
        }
    }
    return ;
}
int main(){
    n=read();m=read();
    for(int i=1;i<=n;i++){
        int x=read(),y=read();
        a[i]=y;add(x,i);add(i,x);
    }
    a[0]=1;dfs(0,0);
    cout<<dp[0][m+1]-1<<endl;
    return 0;
}

蓝魔法师
令m为题目中的k
\(dp_{u,t}\)表示u所在联通块大小为t时,u子树的的方案数
不难想到等同于01背包,u的儿子v选和不选的问题
不同于普通01背包,不选也要进行计算
选:进行背包
不选:\(dp_{u,i}=dp_{u,i}*\sum_{k=1}^m dp_{v,k}\)
但背包过程中,虽然看起来是\(O(n^3)\)
但是有一个小技巧,用sz来优化,因为\(O(\sum sz_i^2) = O(n^2)\)

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define YES {puts("YES");return;}
#define NO {puts("NO");return ;}
using namespace std;
const int maxn=2e6+101;
const int MOD=998244353;
const ll inf=2147383647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
int n,m,a[maxn],sz[maxn];
int tot,head[maxn],nx[maxn],to[maxn];
void add(int x,int y){
    to[++tot]=y;nx[tot]=head[x];head[x]=tot;
}
ll dp[2001][2001];
void dfs(int x,int fa){
    sz[x]=1;dp[x][1]=1;
    for(int i=head[x];i;i=nx[i]){
        int v=to[i];if(v==fa)continue;
        dfs(v,x);

        ll sum=0;
        for(int j=1;j<=min(m,sz[v]);j++)(sum+=dp[v][j])%=MOD;

        vector<ll>f(m+1);

        for(int j=min(m,sz[x]);j>=1;j--){
            for(int k=1;k<=min(m,sz[v]) && k+j<=m;k++){
                //选
                dp[x][j+k]+=dp[v][k]*dp[x][j]%MOD;
                dp[x][j+k]%=MOD;
            }
            dp[x][j]=dp[x][j]*sum%MOD;
            //不选的计算
        }
        sz[x]+=sz[v];
    }
    return ;
}
int main(){
    n=read();m=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        add(x,y);add(y,x);
    }
    dfs(1,0);
    ll ans=0;
    for(int i=1;i<=m;i++){
        ans+=dp[1][i];
        ans%=MOD;
    }
    printf("%lld\n",(ans%MOD+MOD)%MOD);
    return 0;
}

L. Perfect Matchings
题意:对于一个 2∗n 个顶点的完全图,删除给定的一颗生成树,求剩下图的完美匹配数量有多少。
完美匹配,指最大数量的边集合,集合内任意两条边都没有公共顶点。
题解:
正着想是很难的,不妨反着考虑,我们先求完全图的总共完美匹配数,然后容斥减去选择一些树边的完美匹配数

  1. 如何求出n个点完全图的完美匹配数

  1. 如何求出包含一些树边的完美匹配数
    树形背包dp
    设dp[i][j][0/1]表示i子树,匹配树边有j条,i节点不选/选的方案数

    可以枚举每个 i 子树的大小使复杂度达到 \(O(n^2)\)

最后容斥计算结果

点击查看代码
#include <bits/stdc++.h>
#define ll long long
#define int long long
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define YES {puts("Yes");return;}
#define NO {puts("NO");return ;}
using namespace std;
const int maxn=3e5+101;
const int MOD=998244353;
const int inf=2147483647;
const double pi=acos(-1);
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
int n;
int tot,head[maxn],nx[maxn],to[maxn];
void add(int x,int y){to[++tot]=y;nx[tot]=head[x];head[x]=tot;}
int dp[4001][8001][2],sz[maxn];
//dp[i][j][0/1]表示i子树,匹配树边有j条,i节点不选/选的方案数 
void dfs(int x,int fa){
	sz[x]=1;dp[x][0][0]=1;
	for(int i=head[x];i;i=nx[i]){
		int v=to[i];if(v==fa)continue;
		dfs(v,x);
		for(int j=sz[x];j>=0;j--)for(int t=0;t<=sz[v];t++){
			if(t>0){	//t=0 不进行累加,对方案数没有贡献 
				dp[x][j+t][0]+=dp[x][j][0]*(dp[v][t][0]+dp[v][t][1])%MOD;
				dp[x][j+t][0]%=MOD;
				dp[x][j+t][1]+=dp[x][j][1]*(dp[v][t][0]+dp[v][t][1])%MOD;
				dp[x][j+t][1]%=MOD;
			}
			dp[x][j+t+1][1]+=dp[x][j][0]*dp[v][t][0]%MOD;
			dp[x][j+t+1][1]%=MOD;
		}
		sz[x]+=sz[v]; 
	}
	return ;
}
signed main(){
	n=read();
	for(int i=1;i<2*n;i++){
		int x=read(),y=read();
		add(x,y);add(y,x);
	}
	dfs(1,1);
	vector<int>p(n+1);p[0]=1;
	for(int i=1;i<=n;i++)p[i]=p[i-1]*(2*i-1)%MOD;
	int ans=0;
	for(int i=0;i<=n;i++){
		if(i&1)ans=ans-(dp[1][i][0]+dp[1][i][1])*p[n-i]%MOD;
		else ans+=(dp[1][i][0]+dp[1][i][1])*p[n-i]%MOD;
		ans%=MOD;
	}
	cout<<(ans%MOD+MOD)%MOD;
    return 0;
}


2. [USACO 2008 Jan G]Cell Phone Network
经典问题,树的最小支配集
因为一个点有三种可能:1.当前点被选择 2.被儿子支配 3.被父亲支配
因此设
\(dp_{i,0}\)表示选择i节点时,i子树全部被支配的最小值
\(dp_{i,1}\)表示不选择i节点时,i子树全部被支配的最小值
\(dp_{i,2}\)表示不选择i节点时,i子树除了i节点全部被支配的最小值
很容易写出
\(dp_{i,0}=\sum_v min(dp_{v,0},dp_{v,1},dp_{v,2})\)
\(dp_{i,2}=\sum_v dp[v][1]\)
\(dp_{i,1}\)需要贪心选择,因为必须保证i的儿子中至少有一个被选中
若存在至少一个\(dp_{v,0}<dp_{v,1}\),则\(dp_{i,1}=\sum_v min(dp_{v,0},dp_{v,1})\)
否则,我们必须找一个儿子使得\(dp_{v,0}-dp_{v,1}\)最小的来加入到\(dp_{i,1}\)中,具体见代码

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define YES {puts("YES");return;}
#define NO {puts("NO");return ;}
using namespace std;
const int maxn=2e6+101;
const int MOD=998244353;
const ll inf=2147483647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
int n;
int tot,head[maxn],to[maxn],nx[maxn];
void add(int x,int y){
    to[++tot]=y;nx[tot]=head[x];head[x]=tot;
}
ll dp[maxn][3];
void dfs(int x,int fa){
    dp[x][0]=1;
    ll delt=1001;  //最大值,设为最大点数+1
    for(int i=head[x];i;i=nx[i]){
        int v=to[i];if(v==fa)continue;
        dfs(v,x);
        dp[x][0]+=min(dp[v][0],min(dp[v][1],dp[v][2]));
        dp[x][1]+=min(dp[v][0],dp[v][1]);
        dp[x][2]+=dp[v][1];
        delt=min(delt,dp[v][0]-dp[v][1]);
        /*
        等同于上一句
        if(dp[v][0]<dp[v][1])delt=0;
        else delt=min(dp[v][0]-dp[v][1],delt);
        */
    }
    dp[x][1]+=max(0ll,delt);
    return ;
}
int main(){
    n=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        add(x,y);add(y,x);
    }
    dfs(1,1);
    cout<<min(dp[1][0],dp[1][1])<<endl;
    return 0;
}

3.黑白树
对于每次操作必须选择白色节点,可以不用考虑,因为每次选择可以通过交换操作顺序,使得之前选择黑色节点变成选择白色节点
首先肯定选择叶子结点
一个贪心思路就是,能不选就不选,一个节点能往上延伸多少,在其之后再选
但会出现问题,比如下图,贪心操作会选择5,2,1节点,实则选择5,4节点更优

但也会发现,也许只需要维护,当前儿子子树内的最长向上延伸长度即可
\(dp_u\)表示u节点的子树(不包含u节点)能向上最长延伸多少
\(k_u\)表示当前节点能向上最长延伸多少
整个贪心思路就是,能不选就不选,必须要选的话,就从子树中找一个向上延伸最长的点

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define YES {puts("YES");return;}
#define NO {puts("NO");return ;}
using namespace std;
const int maxn=2e6+101;
const int MOD=998244353;
const ll inf=2147483647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
int n,k[maxn];
int tot,head[maxn],nx[maxn],to[maxn];
void add(int x,int y){
    to[++tot]=y;nx[tot]=head[x];head[x]=tot;
}
int dp[maxn],ans;
void dfs(int x,int fa){
    for(int i=head[x];i;i=nx[i]){
        int v=to[i];if(v==fa)continue;
        dfs(v,x);
        dp[x]=max(dp[x],dp[v]-1);
        k[x]=max(k[x],k[v]-1);
    }
    if(dp[x]==0){
        ans++;
        dp[x]=k[x];
    }
    return ;
}
int main(){
    n=read();
    for(int i=2;i<=n;i++){
        int x=read();
        add(i,x);add(x,i);
    }
    for(int i=1;i<=n;i++)k[i]=read();
    dfs(1,1);
    cout<<ans;
    return 0;
}

4.Tree
换根dp
先以1为根进行dp
\(dp_i\)表示i子树的联通集个数
\(dp_u=\prod_v (dp_v+1)\) (+1是包含空集)
然后进行换根dp
其中一个坑点见代码

点击查看代码
#include<functional>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<string>
#include<cstdio>
#include<vector>
#include<cmath>
#include<queue>
#include<deque>
#include<stack>
#include<map>
#include<set>
#define ll long long 
#define pa pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define YES {puts("YES");return;}
#define NO {puts("NO");return ;}
using namespace std;
const int maxn=2e6+101;
const int MOD=1e9+7;
const ll inf=2147483647;
const double eps=1e-12;

ll read(){
    ll x=0,f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
    return x*f;
}
ll power(ll x,ll y){
    ll ans=1;
    while(y){
        if(y&1)ans=ans*x%MOD;
        y>>=1;x=x*x%MOD;
    }
    return ans;
}

int n;
int tot,head[maxn],nx[maxn],to[maxn];
void add(int x,int y){
    to[++tot]=y;nx[tot]=head[x];head[x]=tot;
}
ll dp[maxn],ans[maxn];
void dfs(int x,int fa){
    dp[x]=1;
    for(int i=head[x];i;i=nx[i]){
        int v=to[i];if(v==fa)continue;
        dfs(v,x);
        dp[x]=dp[x]*(dp[v]+1)%MOD;
    }
    return ;
}
void get_dp(int x,int fa){
    ans[x]=dp[x]*(dp[fa]+1)%MOD;
    ll now=dp[x];
    for(int i=head[x];i;i=nx[i]){
        int v=to[i];if(v==fa)continue;
        if((dp[v]+1)%MOD==0){
            //可能(dp[v]+1)%MOD==0
            //导致ans[x]/0出现问题,暴力计算
            dp[x]=1;
            for(int j=head[x];j;j=nx[j]){
                if(to[j]==v)continue;
                dp[x]=dp[x]*(dp[to[j]]+1)%MOD;
            }
        }
        else dp[x]=ans[x]*power((dp[v]+1),MOD-2)%MOD;
        get_dp(v,x);
    }
    dp[x]=now;  //上述代码会改变dp[x],重新改回
    return ;
}
int main(){
    n=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        add(y,x);add(x,y);
    }
    dfs(1,0);get_dp(1,0);
    for(int i=1;i<=n;i++)printf("%lld\n",(ans[i]%MOD+MOD)%MOD);
    return 0;
}

5.划分树
题解

posted @ 2022-08-03 15:38  I_N_V  阅读(18)  评论(0编辑  收藏  举报