Codeforces 543D Road Improvement(树形DP + 乘法逆元)
题目大概说给一棵树,树的边一开始都是损坏的,要修复一些边,修复完后要满足各个点到根的路径上最多只有一条坏的边,现在以各个点为根分别求出修复边的方案数,其结果模1000000007。
不难联想到这题和HDU2196是一种类型的树形DP,因为它们都要分别求各个点的答案。然后解法也不难想:
- dp0[u]表示只考虑以u结点为根的子树的方案数
- dp1[u]表示u结点往上走,倒过来,以它父亲为根那部分的方案数
有了这两部分的结果,对于各个点u的答案就是dp0[u]*(dp1[u]+1)。这两部分求法如下,画画图比较好想:
- 首先求出dp0,这个转移是:dp0[u]=∏(dp0[v]+1)(v是u的孩子),就是对于每个孩子为根的子树的情况总数的乘积,而其中每个孩子的情况总数还要加上一个父亲到孩子之间的边不修复、孩子的子树的边全部修复的情况。
- 然后求出dp1,转移:求dp1[v],u是v的父亲,dp1[v]=dp0[u]/dp0[v]*(dp1[u]+1)。
- 现在问题来了,求dp0[u]/dp0[v],注意到结果模1000000007是一个质数,一开始我用乘法逆元WA了,因为虽然1000000007是质数,但1000000007的倍数不与1000000007互质,模1000000007结果是0,这样就出问题了!
- 本来我想改用线段树做,不过队友提醒说可以分情况讨论,如果不存在与1000000007不互质的数直接逆元搞,存在两个以上不与1000000007互质的数那结果就是0,一个的话。。。。。我就不多说了。
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 using namespace std; 5 #define MAXN 222222 6 struct Edge{ 7 int v,next; 8 }edge[MAXN<<1]; 9 int NE,head[MAXN]; 10 void addEdge(int u,int v){ 11 edge[NE].v=v; edge[NE].next=head[u]; 12 head[u]=NE++; 13 } 14 long long d[2][MAXN]; 15 long long ine(long long a){ 16 long long res=1; 17 int n=1000000007-2; 18 while(n){ 19 if(n&1){ 20 res*=a; 21 res%=1000000007; 22 } 23 a*=a; 24 a%=1000000007; 25 n>>=1; 26 } 27 return res; 28 } 29 void dp0(int u,int fa){ 30 long long res=1; 31 for(int i=head[u]; i!=-1; i=edge[i].next){ 32 int v=edge[i].v; 33 if(v==fa) continue; 34 dp0(v,u); 35 res*=d[0][v]+1; 36 res%=1000000007; 37 } 38 d[0][u]=res; 39 } 40 void dp1(int u,int fa){ 41 int cnt=0; 42 long long tot=1; 43 for(int i=head[u]; i!=-1; i=edge[i].next){ 44 int v=edge[i].v; 45 if(v==fa) continue; 46 if((d[0][v]+1)%1000000007==0) ++cnt; 47 else{ 48 tot*=d[0][v]+1; 49 tot%=1000000007; 50 } 51 } 52 for(int i=head[u]; i!=-1; i=edge[i].next){ 53 int v=edge[i].v; 54 if(v==fa) continue; 55 if(cnt){ 56 if((d[0][v]+1)%1000000007==0 && cnt==1){ 57 d[1][v]=tot; 58 }else d[1][v]=0; 59 }else{ 60 d[1][v]=d[0][u]*ine((d[0][v]+1)%1000000007); 61 d[1][v]%=1000000007; 62 } 63 d[1][v]*=d[1][u]+1; 64 d[1][v]%=1000000007; 65 dp1(v,u); 66 } 67 } 68 int main(){ 69 memset(head,-1,sizeof(head)); 70 int n,a; 71 scanf("%d",&n); 72 for(int i=2; i<=n; ++i){ 73 scanf("%d",&a); 74 addEdge(a,i); 75 addEdge(i,a); 76 } 77 dp0(1,1); 78 dp1(1,1); 79 for(int i=1; i<=n; ++i){ 80 printf("%lld ",d[0][i]*(d[1][i]+1)%1000000007); 81 } 82 return 0; 83 }