BZOJ 3992: [SDOI2015]序列统计
以前觉得FFT和NTT这种东西考试不会考,就没去学。
真香
列个dp的方程,发现每层是求$C_k=\sum_{i*j=k}A_i*B_j$
这个题的$M$是个质数,我们考虑可以用指标的方式表示$i$和$j$,这样$i*j$就变成了$i+j$了。
然后就是最裸的NTT了,套个快速幂就好了。
1 #include <bits/stdc++.h> 2 using namespace std; 3 #define MOD 1004535809 4 inline int Power(int x, int y, int P) { 5 int ret = 1; 6 while(y) { 7 if(y & 1) ret = 1ll * ret * x % P; 8 x = 1ll * x * x % P; y >>= 1; 9 } 10 return ret; 11 } 12 int cnt[10010]; 13 int n, m, x, S; 14 inline int calc() { 15 if(m == 2) return 1; 16 for(int i = 2; ; ++ i) { 17 bool flag = false; 18 for(int j = 2; j * j < m; ++ j) if((m - 1) % j == 0) { 19 if(Power(i, (m - 1) / j, m) == 1) { 20 flag = true; 21 break; 22 } 23 } 24 if(flag) continue; 25 return i; 26 } 27 } 28 int f[32010], g[32010]; 29 int a[32010], b[32010]; 30 int rev[32010], w[32010]; 31 int N; 32 inline void NTT(int *a) { 33 for(int i = 0; i < N; ++ i) { 34 if(rev[i] > i) { 35 swap(a[rev[i]], a[i]); 36 } 37 } 38 for(int d = 1, t = (N >> 1); d < N; d <<= 1, t >>= 1) { 39 for(int i = 0; i < N; i += (d << 1)) { 40 for(int j = 0; j < d; ++ j) { 41 int tmp = 1ll * w[t * j] * a[i + j + d] % MOD; 42 a[i + j + d] = (a[i + j] - tmp + MOD) % MOD; 43 a[i + j] = (a[i + j] + tmp) % MOD; 44 } 45 } 46 } 47 } 48 inline void mul(int *g, int *f) { 49 w[0] = 1; w[1] = Power(3, (MOD - 1) / N, MOD); 50 for(int i = 2; i < N; ++ i) { 51 w[i] = 1ll * w[i - 1] * w[1] % MOD; 52 } 53 for(int i = 0; i < N; ++ i) { 54 a[i] = g[i]; 55 b[i] = f[i]; 56 } 57 NTT(a); NTT(b); 58 for(int i = 0; i < N; ++ i) { 59 a[i] = 1ll * a[i] * b[i] % MOD; 60 } 61 w[0] = 1; w[1] = Power(w[1], MOD - 2, MOD); 62 for(int i = 2; i < N; ++ i) { 63 w[i] = 1ll * w[i - 1] * w[1] % MOD; 64 } 65 NTT(a); 66 int inv = Power(N, MOD - 2, MOD); 67 for(int i = 0; i < N; ++ i) { 68 a[i] = 1ll * a[i] * inv % MOD; 69 } 70 for(int i = 0; i < m - 1; ++ i) { 71 g[i] = (a[i] + a[i + m - 1]) % MOD; 72 } 73 } 74 int main() { 75 scanf("%d%d%d%d", &n, &m, &x, &S); 76 for(int i = 1; i <= S; ++ i) { 77 int x; 78 scanf("%d", &x); 79 cnt[x] = 1; 80 } 81 int G = calc(), pos = -1; 82 for(int i = 1, j = 0; j < m - 1; ++ j, i = (i * G) % m) { 83 if(cnt[i]) f[j] = 1; 84 if(i == x) pos = j; 85 } 86 N = 1; int L = 0; 87 for(; N <= 2 * (m - 1); N <<= 1, ++ L); 88 for(int i = 0; i < N; ++ i) { 89 rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1)); 90 } 91 g[0] = 1; 92 while(n) { 93 if(n & 1) mul(g, f); 94 mul(f, f); n >>= 1; 95 } 96 if(pos != -1) printf("%d\n", g[pos]); 97 else puts("0"); 98 }