关于树上DP的转移方式与复杂度证明

以后要勤写总结了唔

这种优化针对于转移的状态数与子树规模有关的柿子。

例如对于n个树型依赖物品的树上背包dp,每个节点是一个物品且大小为1,设待转移结点u则u的背包容量不会超过u的子树规模,转移子节点v占据的容积不会超过v的子树规模。所以我们有以下转移方式:

  • 最脑残的,直接开siz[u]个固定背包容量(第一层for ∑siz[v]),用儿子来”填“(第二层for siz[v])。显然在转移非最后一个儿子时有大量无用状态。
  • 较脑小的(是我 并不),开∑已扫过儿子和当前儿子的siz(第一层for ∑已扫过儿子和当前儿子的siz「要为该儿子‘填’留出空间」),依然是用儿子”填“待转移状态(第二层for siz[v])。然而这种转移复杂度依然不乐观,接近于n^2(不会证,但的确被总复杂度n^2的数据卡掉了)再加上dfs n^3,不够优秀。
  • 卓越算法,既然”填“不行,我们可以”推“(因为有效的状态比无效状态少)。开∑已扫过儿子的siz(第一层for 都可能对答案造成贡献),这次我们用儿子和左边”推“状态,类似dp[u][k+j]的形式。对于一些题面较为复杂的限制条件,我们可以开个temp数组(类似于temp[0/1滚动儿子][0/1限制条件 状态1][状态2]...)完成以上转移,然后再把temp塞到dp数组里。这种转移依然有三层for,但确有O(n^2)的优秀复杂度。以下给出证明:

首先脑补一下转移的图(其实是我懒得画了),对于当前转移的v子树(之所以说子树是因为最多转移它的siz)与之前已考虑的子子树中所有点构成点对且点对的lca为u,当完成对所有点的dfs,所有点作为lca的情况都被考虑且保证点对不重复(点对的lca唯一),将构成n^2个点对。得证。

 1     F(i,2,son[u][0])
 2     {
 3         int v=son[u][i];
 4         F(j,0,deep[u])
 5         {
 6             F(k,0,deep[son[u][i]])
 7             {
 8                 temp[i&1][0][max(j,k+1)]+=temp[i&1^1][0][j]*f[v][k]%p;       temp[i&1][0][max(j,k+1)]%=p;
 9                 temp[i&1][1][max(j,k)]+=temp[i&1^1][0][j]*f[v][k]%p;         temp[i&1][1][max(j,k)]%=p;
10                 temp[i&1][1][max(j,k+1)]+=temp[i&1^1][1][j]*f[v][k]%p;        temp[i&1][1][max(j,k+1)]%=p;
11             }
12         }
13         F(j,0,deep[u]) temp[i&1^1][0][j]=temp[i&1^1][1][j]=0;
14         deep[u]=max(deep[u],deep[son[u][i]]+1);
15     }
16     F(j,0,deep[u]) f[u][j]=temp[son[u][0]&1][1][j];
大概长这样

 

相关题目(在本家OJ NOIP模拟3里)

树上染色

点与点之间不好推,平时应多注意边的贡献。对于一条边,它的贡献为$一边的黑点数*另一边的黑点数*边权+一边的白点数*另一边的白点数*边权$,我们可以定义状态dp[u][j]为u子树中选出j个黑点的最大收益,简化下,即在u阶段下在$|son[u]|$个分组中每个至多选一个(可以不选)状态权值为dp[v][k]重量为k使背包dp[u][j]收益最大,然后就可以在树上进行背包DP。

观察$N<=2000$,那么我们就需要top所说的$O(N^2)$转移了。由于这题dp柿子比较简单,用不用temp无所谓,只要“推”状态就可以达到。

 1 #include<cstdio>
 2 #include<vector>
 3 #include<set>
 4 #include<cmath>
 5 #include<cstring>
 6 #include<algorithm>
 7 #define MAXN 2005
 8 #define ll long long
 9 #define reg register
10 #define F(i,a,b) for(register int (i)=(a);(i)<=(b);++(i))
11 using namespace std;
12 inline int read();
13 struct R{
14     int u,v,next;
15     ll w;
16 }r[MAXN<<1];
17 int n,black_tot;
18 int fir[MAXN],o,siz[MAXN];
19 ll dp[MAXN][MAXN];
20 void add(int u,int v,ll w)
21 {
22     r[++o].u=u;
23     r[o].v=v;
24     r[o].w=w;
25     r[o].next=fir[u];
26     fir[u]=o;
27 }
28 void dfs(int u,int fa)
29 {
30     siz[u]=1;
31     for(reg int i=fir[u];i;i=r[i].next)
32     {
33         int v=r[i].v;
34         if(v==fa) continue;
35         dfs(v,u);
36         for(reg int j=min(siz[u],black_tot);j>=0;--j)
37             for(reg int k=min(siz[v],black_tot-j);k>=0;--k)
38                 dp[u][j+k]=max(dp[u][j+k],dp[u][j]+dp[v][k]+1ll*k*(black_tot-k)*r[i].w+1ll*(siz[v]-k)*(n-siz[v]-black_tot+k)*r[i].w);
39         siz[u]+=siz[v];
40     }
41 }
42 int main()
43 {
44     n=read(); black_tot=read();
45     int a,b;
46     ll t;
47     F(i,1,n-1)
48     {
49         a=read(); b=read(); scanf("%lld",&t);
50         add(a,b,t); add(b,a,t);
51     }
52     dfs(1,0);
53     printf("%lld",dp[1][black_tot]);
54     return 0;
55 }
56 inline int read()
57 {
58     int x=0;
59     char tc=getchar();
60     while(tc<'0'||tc>'9') tc=getchar();
61     while(tc>='0'&&tc<='9') x=x*10+tc-48,tc=getchar();
62     return x;
63 }
不用temp
 1 #include<cstdio>
 2 #include<vector>
 3 #include<set>
 4 #include<cmath>
 5 #include<cstring>
 6 #include<algorithm>
 7 #define MAXN 2005
 8 #define ll long long
 9 #define inf (1e9)+1
10 #define reg register
11 #define F(i,a,b) for(register int (i)=(a);(i)<=(b);++(i))
12 using namespace std;
13 inline int read();
14 struct R{
15     int u,v,w,next;
16 }r[MAXN<<1];
17 int n,black_tot;
18 int fir[MAXN],o;
19 int siz[MAXN];
20 ll dp[MAXN][MAXN],temp[2][MAXN];
21 void add(int u,int v,int w)
22 {
23     r[++o].u=u;
24     r[o].v=v;
25     r[o].w=w;
26     r[o].next=fir[u];
27     fir[u]=o;
28 }
29 void dfs(int u,int fa)
30 {
31     siz[u]=1;
32     for(reg int i=fir[u];i;i=r[i].next)
33     {
34         int v=r[i].v;
35         if(v==fa) continue;
36         dfs(v,u);
37     }
38     memset(temp,0,sizeof(temp));
39     int cur=0;
40     for(reg int i=fir[u];i;i=r[i].next)
41     {
42         int v=r[i].v;
43         if(v==fa) continue;
44         for(reg int j=min(siz[u],black_tot);j>=0;--j)
45         {
46             for(reg int k=0;k<=min(siz[v],black_tot-j);++k)
47         //    for(reg int k=min(siz[v],black_tot-j);k>=0;--k)
48             {
49                 temp[cur^1][j+k]=max(temp[cur^1][j+k],temp[cur][j]+dp[v][k]+1ll*k*(black_tot-k)*r[i].w+1ll*(siz[v]-k)*(n-siz[v]-black_tot+k)*r[i].w);
50             }
51         }
52         siz[u]+=siz[v];
53         for(reg int j=0;j<=siz[u];++j) temp[cur][j]=0;
54         cur^=1;
55     }
56     for(reg int j=0;j<=siz[u];++j) dp[u][j]=temp[cur][j];
57 }
58 int main()
59 {
60 //    freopen("data.in","r",stdin);
61 //    freopen("data.out","w",stdout);
62     n=read(); black_tot=read();
63     int a,b;
64     ll t;
65     F(i,1,n-1)
66     {
67         a=read(); b=read(); scanf("%lld",&t);
68         add(a,b,t); add(b,a,t);
69     }
70     dfs(1,0);
71     printf("%lld\n",dp[1][black_tot]);
72     return 0;
73 }
74 inline int read()
75 {
76     int x=0;
77     char tc=getchar();
78     while(tc<'0'||tc>'9') tc=getchar();
79     while(tc>='0'&&tc<='9') x=x*10+tc-48,tc=getchar();
80     return x;
81 }
temp

 

可怜与超市

搜题解有惊喜(滑稽)

优惠券之间有限制,$1<xi<i$?联想下随机树的生成方法(蓝皮书后),不难发现限制关系是树形的。

我们把$xi->i$连边建树,那么对于一个节点v能使用优惠券,当且仅当它的父亲节点u能使用优惠券且购买。

这题稍有一点反套路在于b很大,不能做常规的背包dp定义第二维为背包容量。

那么我们无妨改变dp数组含义定义$dp[u][j][0/1]$为子树u中购买j商品用(1)或不用(0)卷的最小花费,显然当值大于b时不再能有贡献。

然后就可以写出dp式子:

dp[u][j][0]=min(dp[u][j-k][0]+dp[v][k][0])

dp[u][j][1]=min(dp[u][j-k][1]+min(dp[v][k][0]+dp[v][k][1]))  这式子一开始我还以为是错的e

然后你就会T70。这就是上文所述的第二种转移。

我们把它转化为推的形式,且用temp数组维护。

temp[cur^1][0][j+k]=min(temp[cur^1][0][j+k],min(temp[cur][0][j]+dp[v][k][0],temp[cur][0][j+k]));

if(j)temp[cur^1][1][j+k]=min(temp[cur^1][1][j+k],min(temp[cur][1][j]+min(dp[v][k][0],dp[v][k][1]),temp[cur][1][j+k]));

滚动维护前缀子树temp和当前子树v更新。

siz[u]的更新一定一定要放在后边,不然就会退化成$O(n^3)$

 

posted @ 2019-07-17 16:42  hzoi_yzh  阅读(233)  评论(0编辑  收藏  举报