Codeforces997D Cycles in product 【FFT】【树形DP】
题目大意:
给两个树,求环的个数。
题目分析:
出题人摆错题号系列。
通过画图很容易就能想到把新图拆在两个树上,在树上游走成环。
考虑DP状态F,G,T。F表示最终答案,T表示儿子不考虑父亲,G表示父亲不考虑儿子。T通过从下往上做NTT,G通过从上往下做NTT。F顺便做NTT。
最后做一下拼接就行。
代码:
1 #include<bits/stdc++.h> 2 using namespace std; 3 4 const int maxn = 4020; 5 const int mod = 998244353; 6 const int gg = 3; 7 8 int n[2],k; 9 10 vector <int> g[2][maxn]; 11 12 int f[2][maxn][80],gi[2][maxn][80],T[2][maxn][80]; 13 14 int C[100][100]; 15 16 int fast_pow(int now,int pw){ 17 int ans = 1,bit = 1,dt = now; 18 while(bit <= pw){ 19 if(bit & pw){ans = (1ll*ans*dt)%mod;} 20 bit <<=1; dt = (1ll*dt*dt)%mod; 21 } 22 return ans; 23 } 24 25 void read(){ 26 scanf("%d%d%d",&n[0],&n[1],&k); 27 for(int i=1;i<n[0];i++){ 28 int u,v; scanf("%d%d",&u,&v); 29 g[0][u].push_back(v); g[0][v].push_back(u); 30 } 31 for(int i=1;i<n[1];i++){ 32 int u,v; scanf("%d%d",&u,&v); 33 g[1][u].push_back(v); g[1][v].push_back(u); 34 } 35 } 36 37 int ord[260]; 38 39 void NTT(int *d,int len,int dr){ 40 for(int i=0;i<len;i++) if(ord[i] < i) swap(d[i],d[ord[i]]); 41 for(int i=1;i<len;i<<=1){ 42 int wn = fast_pow(gg,(mod-1)/(2*i)); 43 if(dr == -1) wn = fast_pow(wn,mod-2); 44 for(int j=0;j<len;j+=(i<<1)){ 45 for(int k=0,w=1;k<i;k++,w = (1ll*w*wn)%mod){ 46 int x = d[j+k],y = (1ll*w*d[j+k+i])%mod; 47 d[j+k] = (x+y)%mod; 48 d[j+k+i] = (x-y+mod)%mod; 49 } 50 } 51 } 52 if(dr == -1){ 53 int iv = fast_pow(len,mod-2); 54 for(int i=0;i<len;i++) d[i] = (1ll*d[i]*iv)%mod; 55 } 56 } 57 58 int A[260],B[260]; 59 int fi[260],A0[260]; 60 61 void INV(){ 62 int len = 1,bit = 0; while(len <= k) len<<=1,bit++; 63 memset(A0,0,sizeof(A0));memset(fi,0,sizeof(fi)); 64 A0[0] = 1; 65 for(int i=2,j=1;i<=len;i<<=1,j++){ 66 for(int k=0;k<i;k++) fi[k] = A[k]; 67 int rl = i*2,rb = j+1; 68 for(int k=0;k<rl;k++) ord[k] = (ord[k>>1]>>1) + ((k&1)<<rb-1); 69 NTT(A0,rl,1); NTT(fi,rl,1); 70 for(int k=0;k<rl;k++){ 71 A0[k] = (2*A0[k]-(1ll*fi[k]*A0[k]%mod)*A0[k]%mod)%mod; 72 if(A0[k] < 0) A0[k] += mod; 73 } 74 NTT(A0,rl,-1); 75 for(int k=i;k<rl;k++) A0[k] = fi[k] = 0; 76 } 77 for(int i=0;i<=k;i++) A[i] = A0[i]; 78 } 79 80 void dfs1(int kd,int now,int fa){ 81 for(auto it:g[kd][now]){ 82 if(it == fa) continue; 83 dfs1(kd,it,now); 84 } 85 memset(A,0,sizeof(A)); 86 for(auto it:g[kd][now]){ 87 if(it == fa) continue; 88 for(int i=0;i<=k-2;i+=2) A[i+2] = (A[i+2]+T[kd][it][i])%mod; 89 } 90 A[0] -= 1; if(A[0] < 0) A[0] += mod; 91 for(int i=0;i<=k;i++){ A[i] *= -1; if(A[i] < 0) A[i] += mod;} 92 INV(); 93 for(int i=0;i<=k;i++) T[kd][now][i] = A[i]; 94 } 95 96 void dfs2(int kd,int now,int fa){ 97 memset(B,0,sizeof(B)); 98 for(auto it:g[kd][now]){ 99 if(it == fa) continue; 100 for(int i=0;i<=k-2;i+=2) B[i+2] = (B[i+2]+T[kd][it][i])%mod; 101 } 102 for(int i=0;i<=k-2;i+=2) B[i+2] = (B[i+2]+gi[kd][now][i])%mod; 103 for(auto it:g[kd][now]){ 104 if(it == fa) continue; 105 for(int i=0;i<=k-2;i+=2) B[i+2] = (B[i+2]+mod-T[kd][it][i])%mod; 106 memset(A,0,sizeof(A)); 107 for(int i=0;i<=k;i++) A[i] = (mod-B[i])%mod; A[0] = (1-A[0]+mod)%mod; 108 INV(); for(int i=0;i<=k;i++) gi[kd][it][i] = A[i]; 109 for(int i=0;i<=k-2;i+=2) B[i+2] = (B[i+2]+T[kd][it][i])%mod; 110 } 111 memset(A,0,sizeof(A)); 112 for(int i=0;i<=k;i++) A[i] = (mod-B[i])%mod; A[0] = (1-A[0]+mod)%mod; 113 INV(); for(int i=0;i<=k;i++) f[kd][now][i] = A[i]; 114 for(auto it:g[kd][now]){ 115 if(it == fa) continue; 116 dfs2(kd,it,now); 117 } 118 } 119 120 void solve(int kd){ 121 dfs1(kd,1,0); 122 dfs2(kd,1,0); 123 } 124 125 void work(){ 126 solve(0); 127 solve(1); 128 for(int i=1;i<=k;i++){ 129 C[i][0] = C[i][i] = 1; 130 for(int j=1;j<i;j++) C[i][j] = (C[i-1][j] + C[i-1][j-1]) % mod; 131 } 132 int ans = 0; 133 for(int i=0;i<=k;i++){ 134 int s1 = 0,s2 = 0; 135 for(int j=1;j<=n[0];j++) s1 += f[0][j][i],s1 %= mod; 136 for(int j=1;j<=n[1];j++) s2 += f[1][j][k-i],s2 %= mod; 137 int pp = (1ll*s1*s2)%mod;pp = (1ll*C[k][i]*pp)%mod; 138 ans += pp; ans %= mod; 139 } 140 printf("%d",ans); 141 } 142 143 int main(){ 144 read(); 145 work(); 146 return 0; 147 }