[周末训练]加分二叉树
题目
【题目描述】
原题来自:NOIP 2003
设一个$N$个节点的二叉树$tree$的中序遍历为$(1,2,3,......,N)$ ,其中数字$1,2,3,......,N$为节点编号。每个节点都有一个分数(均为正整数),记第$i$个节点的分数为$d_i$,及它的每个子树都有一个加分,任一棵子树$subtree$(也包含$tree$本身)的加分计算方法如下:
记$subtree$的左子树加分为$l$,右子树加分为$r$,$subtree$的根的分数为$a$,则$subtree$的加分为:$$l×r+a$$
若某个子树为空,规定其加分为$1$,叶子的加分就是叶节点本身的分数。不考虑它的空子树。
试求一棵符合中序遍历为$1,2,3,......,N$且加分最高的二叉树$tree$。
要求输出:
- $tree$的最高加分
- $tree$的前序遍历
【输入格式】
第一行一个整数$N$表示节点个数;
第二行$N$个空格隔开的整数,表示各节点的分数$d_i$。
【输出格式】
第一行一个整数,为最高加分$b$;
第二行$N$个用空格隔开的整数,为该树的前序遍历。
【样例】
样例输入
5 5 7 1 2 10
样例输出
145 3 1 2 4 5
【数据范围与提示】
对于$100%$的数据,$N<30$,$b<100$,结果不超过$4×10^9$。
题解
【做题经历】
以为是一道树$dp$,但是做了半天都没啥思路,果断弃疗
但是不是树$dp$是什么呢,然后想了一下,发现题目给定这个树的中序遍历是$1,2,3,......,N$,但是也只有一个树的中序遍历,也就是说我们并不知道这个树的形状。
那么对于一个中序遍历片段$l,l+1,l+2,...,r$,假设它是一棵子树(因为中序遍历我们并不能确定树的形状,随便抓一段出来都可以当子树)
我们并不知道到底哪一个数是这个子树的根节点,就需要$k$来枚举根节点,那么$l≤k≤r$
那么为了保存状态,定义状态$dp[l][r]$:编号从$l$到$r$为一棵子树时,它的最高加分
为了保存这个子树的根(因为最后还是输出先序遍历结果),再开一个数组$rt[i][j]$:当子树$[l,r]$取最大加分时,它的根的编号
那么就有状转$$dp[i][j]=max\left \{dp[i][k-1]×dp[k+1][r]+d[k]|k \in [l,r]\right \}$$
【正解】
一道水题我怎么不能手切,去看看我的亲身经历吧
代码请放心食用
#include<bits/stdc++.h> using namespace std; template<class T>inline void qread(T& x){ char c;x=0;bool flg=false; while((c=getchar())<'0'||'9'<c)if(c=='-')flg=true; for(x=(c^48);'0'<=(c=getchar())&&c<='9';x=(x<<1)+(x<<3)+(c^48)); if(flg)x=-x; } template<class T,class... Args>inline void qread(T& x,Args&... args){qread(x),qread(args...);} inline int rqread(){ char c;int x=0,f=1; while((c=getchar())<'0'||'9'<c)if(c=='-')f=-1; for(x=(c^48);'0'<=(c=getchar())&&c<='9';x=(x<<1)+(x<<3)+(c^48)); return x*f; } const int MAXN=30; int N,dp[MAXN+5][MAXN+5],rt[MAXN+5][MAXN+5]; int make(const int l,const int r){ if(l>r)return 1; if(dp[l][r]!=-1)return dp[l][r]; int now; for(int k=l;k<=r;++k){ now=make(l,k-1)*make(k+1,r)+dp[k][k]; if(now>dp[l][r]){ dp[l][r]=now; rt[l][r]=k; } } return dp[l][r]; } bool flg=false; void output(const int l,const int r){ if(l>r)return; if(flg)putchar(' '); else flg=true; printf("%d",rt[l][r]); output(l,rt[l][r]-1); output(rt[l][r]+1,r); } signed main(){ qread(N); memset(dp,-1,sizeof dp); for(int i=1;i<=N;++i)qread(dp[i][i]),rt[i][i]=i; printf("%d\n",make(1,N)); output(1,N),putchar('\n'); return 0;
}