洛谷1623 树的匹配 树形动态规划 高精度
欢迎访问~原文出处——博客园-zhouzhendong
去博客园看该题解
题目传送门 - 洛谷1623
题目描述
题目描述
给一棵树,你可以匹配有边相连的两个点,问你这棵树的最大匹配时多少,并且计算出有多少种最大匹配。
输入输出格式
输入格式:
第一行一个数N,表示有多少个结点。
接下来N行,每行第一个数,表示要描述的那个结点。然后一个数m,表示这个结点有m个儿子,接下来m个数,表示它的m个儿子的编号。
【数据规模】
N≤1000,其中40%的数据答案不超过108
输出格式:
输出两行,第一行输出最大匹配数,第二行输出最大匹配方案数。
题解
我们首先考虑一下这一题的数据范围~
高精度。
(猛啊出题人)
为什么不可以人性化一点?
那么既然这样,我们在设计算法的时候,就要尽量避免复杂的运算。
我们用dp[i][0]代表在以第i个节点为根的子树里,根节点i不取,所能得到的最大匹配数;
dp[i][1]代表在以第i个节点为根的子树里,根节点i要取,所能得到的最大匹配数;
用sum[i][0]代表在以以第i个节点为根的子树里,根节点i不取,所能得到的最大匹配数的方案总数;
dp[i][1]代表在以第i个节点为根的子树里,根节点i要取,所能得到的最大匹配数的方案总数;
然后就是不断的转移。
对于dp[i][0],我们有简单的转移,根节点不取,那么每个子树随意,那么就是所有子节点s的max(dp[s][0],dp[s][1])加起来。
对于sum[i][0],我们如果一个都不选,那么显然是一种情况,然后每次选择子树的时候,我们要考虑3中情况。
对于子树s,
如果dp[s][0]>dp[s][1],那么我们会选择dp[s][0],那么sum[i][0]乘上sum[s][0]即可。
如果dp[s][0]<dp[s][1],那么也同理,sum[i][0]乘上sum[s][1]即可。
如果dp[s][0]=dp[s][1],那么,不管是s取或者不取的情况,都对当前答案有贡献,那么sum[i][j]乘上(sum[s][0]+sum[s][1])即可。
比较麻烦的是dp[i][1]和sum[i][1]。
对此,我们首先要保证有一个子节点s,dp[s][0]来和根节点匹配。
那么其他的还是和刚才一样的,选大的。
所以,我们可以求出前后缀dp以及前后缀sum,然后一个一个子节点搜过去,把答案分成3段,一段是左边,一段是右边,一段是枚举到的字节点。
如果选择该子节点的匹配数大于当然得到的dp[i][1],那么当然更新dp[i][1],并重新累加sum[i][1];
否则如果匹配数等于当前得到的dp[i][1],那么也对当前答案有贡献,累加即可。
当然高精度也是卡了很多选手的地方。
具体操作详见代码。
代码
#include <cstring> #include <algorithm> #include <cstdlib> #include <cstdio> #include <cmath> #include <vector> using namespace std; typedef long long LL; const LL mod=10000000; struct LLLL{ LL d,v[100]; LLLL(){} LLLL(int x){ d=0; memset(v,0,sizeof v); while (x){ v[++d]=x%mod; x/=mod; } } void Print(){ printf("%lld",v[d]); for (int i=d-1;i>=1;i--) printf("%07lld",v[i]); } LLLL operator +(const LLLL &x){ LLLL Ans(0); Ans.d=max(d,x.d); for (int i=1;i<=Ans.d;i++) Ans.v[i]=v[i]+x.v[i]; for (int i=1;i<=Ans.d;i++){ Ans.v[i+1]+=Ans.v[i]/mod; Ans.v[i]%=mod; } if (Ans.v[Ans.d+1]) Ans.d++; return Ans; } LLLL operator *(const LLLL &y){ LLLL x=*this,Ans(0); for (int i=1;i<=x.d;i++) for (int j=1;j<=y.d;j++) Ans.v[i+j-1]+=x.v[i]*y.v[j]; Ans.d=x.d+y.d-1; for (int i=1;i<=Ans.d;i++){ Ans.v[i+1]+=Ans.v[i]/mod; Ans.v[i]%=mod; } if (Ans.v[Ans.d+1]) Ans.d++; return Ans; } }; const int N=1000+5; vector <int> son[N]; LLLL sum[N][2],sumL[N],sumR[N]; int n,dp[N][2],maxL[N],maxR[N],in[N]; void dfs(int rt){ for (int i=0;i<son[rt].size();i++) dfs(son[rt][i]); dp[rt][0]=0; sum[rt][0]=1; for (int i=0;i<son[rt].size();i++){ int s=son[rt][i]; if (dp[s][0]>dp[s][1]){ dp[rt][0]+=dp[s][0]; sum[rt][0]=sum[rt][0]*sum[s][0]; } if (dp[s][0]<dp[s][1]){ dp[rt][0]+=dp[s][1]; sum[rt][0]=sum[rt][0]*sum[s][1]; } if (dp[s][0]==dp[s][1]){ dp[rt][0]+=dp[s][1]; sum[rt][0]=sum[rt][0]*(sum[s][0]+sum[s][1]); } } memset(sumL,0,sizeof sumL); memset(sumR,0,sizeof sumR); memset(maxL,0,sizeof maxL); memset(maxR,0,sizeof maxR); sumL[0]=1,sumR[son[rt].size()+1]=1; for (int i=1;i<=son[rt].size();i++){ int s=son[rt][i-1]; if (dp[s][0]>dp[s][1]){ maxL[i]=maxL[i-1]+dp[s][0]; sumL[i]=sumL[i-1]*sum[s][0]; } if (dp[s][0]<dp[s][1]){ maxL[i]=maxL[i-1]+dp[s][1]; sumL[i]=sumL[i-1]*sum[s][1]; } if (dp[s][0]==dp[s][1]){ maxL[i]=maxL[i-1]+dp[s][0]; sumL[i]=sumL[i-1]*(sum[s][0]+sum[s][1]); } } for (int i=son[rt].size();i>=1;i--){ int s=son[rt][i-1]; if (dp[s][0]>dp[s][1]){ maxR[i]=maxR[i+1]+dp[s][0]; sumR[i]=sumR[i+1]*sum[s][0]; } if (dp[s][0]<dp[s][1]){ maxR[i]=maxR[i+1]+dp[s][1]; sumR[i]=sumR[i+1]*sum[s][1]; } if (dp[s][0]==dp[s][1]){ maxR[i]=maxR[i+1]+dp[s][0]; sumR[i]=sumR[i+1]*(sum[s][0]+sum[s][1]); } } dp[rt][1]=0; sum[rt][1]=0; for (int i=1;i<=son[rt].size();i++){ int s=son[rt][i-1]; int v=dp[s][0]+maxL[i-1]+maxR[i+1]+1; if (v>dp[rt][1]){ dp[rt][1]=v; sum[rt][1]=sum[s][0]*sumL[i-1]*sumR[i+1]; } else if (v==dp[rt][1]) sum[rt][1]=sum[s][0]*sumL[i-1]*sumR[i+1]+sum[rt][1]; } } int main(){ freopen("tree.in","r",stdin); freopen("tree.out","w",stdout); scanf("%d",&n); for (int i=1;i<=n;i++) son[i].clear(); memset(in,0,sizeof in); for (int i=1,bh,sz;i<=n;i++){ scanf("%d%d",&bh,&sz); for (int j=1,s;j<=sz;j++) scanf("%d",&s),son[bh].push_back(s),in[s]++; } int rt=-1; for (int i=1;i<=n&&rt==-1;i++) if (in[i]==0) rt=i; dfs(rt); int ansv=max(dp[rt][0],dp[rt][1]); LLLL anssum(0); if (ansv==dp[rt][0]) anssum=anssum+sum[rt][0]; if (ansv==dp[rt][1]) anssum=anssum+sum[rt][1]; printf("%d\n",ansv); anssum.Print(); fclose(stdin);fclose(stdout); return 0; }