$NTT$(快速数论变换)
- 概念引入
- 阶
对于$p \in N_+$且$(a, \ p) = 1$,满足$a^r \equiv 1 (mod \ p)$的最小的非负$r$为$a$模$p$意义下的阶,记作$\delta_p(a)$
- 原根
定义:若$p \in N_+$且$a \in N$,若$\delta_p(a) = \phi(p)$,则称$a$为模$p$的一个原根
相关定理:
- 若一个数$m$拥有原根,那么它必定为$2, \ 4, \ p^t, \ 2p^t \ (p$为奇质数$)$的其中一个
- 每个数$p$都有$\phi(\phi(p))$个原根
证明:若$p \in N_+$且$(a, \ p) = 1$,正整数$r$满足$a^r \equiv 1 (mod \ p)$,那么$\delta(p) | r$,由此推广,可知$\delta(p) | \phi(p)$,所以$p$的原根个数即为$p$之前与$\phi(p)$互质的数,即$\phi(p)$故定理成立
- 若$g$是$m$的一个原根,则$g, \ g^1, \ g^2, \ ..., \ g^{\phi(m)} (mod \ p)$两两不同
原根求法:
将$\phi(m)$质因数分解,得$\phi(m) = p_1^{c_1} * p_2^{c_2} * ... * p_k^{c_k}$
那么所有$g$满足$g^{\frac{\phi(m)}{p_i}} \neq 1 (mod \ m)$即为$m$的原根
- $NTT$
由于$FTT$涉及到复数的运算,所以常数很大,而$NTT$仅需使用长整型,可大大优化常数
能够将原根代替单位根进行计算,是因为它们的性质相似,至少在单位根需要的那几个性质原根都满足,当然,要能够进行$NTT$,需要满足模数$p$为质数,且$p = ax + 1$其中$x$为$2$的次幂,那么一般能满足条件的数(常用)有:
$|\ \ \ \ \ \ \ \ \ \ \ \ p \ \ \ \ \ \ \ \ \ \ \ \ |\ \ \ \ g \ \ \ \ |$
$|\ \ \ \ 469762049 \ \ \ \ |\ \ \ \ 3 \ \ \ \ |$
$|\ \ \ \ 998244353 \ \ \ \ |\ \ \ \ 3 \ \ \ \ |$
$|\ \ \ 1004535809 \ \ \ |\ \ \ \ 3 \ \ \ \ |$
那么,就可以将单位根$\omega_n$替换为$g^{\frac{p - 1}{n}}$进行$NTT$了
- 代码
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <algorithm> 5 #include <cmath> 6 7 #define MOD 998244353 8 #define g 3 9 10 using namespace std; 11 12 typedef long long LL; 13 14 const int MAXN = (1 << 22); 15 16 LL power (LL x, int p) { 17 LL cnt = 1; 18 while (p) { 19 if (p & 1) 20 cnt = cnt * x % MOD; 21 22 x = x * x % MOD; 23 p >>= 1; 24 } 25 26 return cnt; 27 } 28 29 const LL invg = power (g, MOD - 2); 30 31 int N, M; 32 LL A[MAXN], B[MAXN]; 33 34 int oppo[MAXN]; 35 int limit; 36 void NTT (LL* a, int inv) { 37 for (int i = 0; i < limit; i ++) 38 if (i < oppo[i]) 39 swap (a[i], a[oppo[i]]); 40 for (int mid = 1; mid < limit; mid <<= 1) { 41 LL ome = power (inv == 1 ? g : invg, (MOD - 1) / (mid << 1)); 42 for (int n = mid << 1, j = 0; j < limit; j += n) { 43 LL x = 1; 44 for (int k = 0; k < mid; k ++, x = x * ome % MOD) { 45 LL a1 = a[j + k], xa2 = x * a[j + k + mid] % MOD; 46 a[j + k] = (a1 + xa2) % MOD; 47 a[j + k + mid] = (a1 - xa2 + MOD) % MOD; 48 } 49 } 50 } 51 } 52 53 int getnum () { 54 int num = 0; 55 char ch = getchar (); 56 57 while (! isdigit (ch)) 58 ch = getchar (); 59 while (isdigit (ch)) 60 num = (num << 3) + (num << 1) + ch - '0', ch = getchar (); 61 62 return num; 63 } 64 65 int main () { 66 N = getnum (), M = getnum (); 67 for (int i = 0; i <= N; i ++) 68 A[i] = (int) getnum (); 69 for (int i = 0; i <= M; i ++) 70 B[i] = (int) getnum (); 71 72 int n, lim = 0; 73 for (n = 1; n <= N + M; n <<= 1, lim ++); 74 for (int i = 0; i <= n; i ++) 75 oppo[i] = (oppo[i >> 1] >> 1) | ((i & 1) << (lim - 1)); 76 limit = n; 77 NTT (A, 1); 78 NTT (B, 1); 79 for (int i = 0; i <= n; i ++) 80 A[i] = A[i] * B[i] % MOD; 81 NTT (A, - 1); 82 LL invn = power (n, MOD - 2); 83 for (int i = 0; i <= N + M; i ++) { 84 if (i) 85 putchar (' '); 86 printf ("%d", (int) (A[i] * invn % MOD)); 87 } 88 puts (""); 89 90 return 0; 91 } 92 93 /* 94 1 2 95 1 2 96 1 2 1 97 */ 98 99 /* 100 5 5 101 1 7 4 0 9 4 102 8 8 2 4 5 5 103 */
- 任意模数$NTT$(三模数$NTT$法)
有公式
直接乘会爆$long \ long$,就先将上面的用$CRT$合并,得
那么设$Ans = kM + A$,则有
直接处理即可
- 代码(任意模数$NTT$)
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 5 using namespace std; 6 7 typedef long long LL; 8 9 const int MAXN = (1 << 20); 10 11 const LL MOD[3]= {469762049, 998244353, 1004535809}; // 三模数 12 const LL g = 3; 13 const long double eps = 1e-03; 14 15 LL multi (LL a, LL b, LL p) { // 快速乘 16 a %= p, b %= p; 17 return ((a * b - (LL) ((LL) ((long double) a / p * b + eps) * p)) % p + p) % p; 18 } 19 LL power (LL x, LL p, LL mod) { 20 LL cnt = 1; 21 while (p) { 22 if (p & 1) 23 cnt = cnt * x % mod; 24 25 x = x * x % mod; 26 p >>= 1; 27 } 28 29 return cnt; 30 } 31 const LL invg[3]= {power (g, MOD[0] - 2, MOD[0]), power (g, MOD[1] - 2, MOD[1]), power (g, MOD[2] - 2, MOD[2])}; 32 33 int N, M; 34 LL P; 35 36 LL A[MAXN], B[MAXN]; 37 38 int limit; 39 int oppo[MAXN]; 40 void NTT (LL* a, int inv, int type) { 41 for (int i = 0; i < limit; i ++) 42 if (i < oppo[i]) 43 swap (a[i], a[oppo[i]]); 44 for (int mid = 1; mid < limit; mid <<= 1) { 45 LL ome = power (inv == 1 ? g : invg[type], (MOD[type] - 1) / (mid << 1), MOD[type]); 46 for (int n = mid << 1, j = 0; j < limit; j += n) { 47 LL x = 1; 48 for (int k = 0; k < mid; k ++, x = x * ome % MOD[type]) { 49 LL a1 = a[j + k], xa2 = x * a[j + k + mid] % MOD[type]; 50 a[j + k] = (a1 + xa2) % MOD[type]; 51 a[j + k + mid] = (a1 - xa2 + MOD[type]) % MOD[type]; 52 } 53 } 54 } 55 } 56 57 LL ntta[3][MAXN], nttb[3][MAXN]; 58 void NTT_Main () { 59 int n, lim = 0; 60 for (n = 1; n <= N + M; n <<= 1, lim ++); 61 limit = n; 62 for (int i = 0; i < n; i ++) 63 oppo[i] = (oppo[i >> 1] >> 1) | ((i & 1) << (lim - 1)); 64 for (int i = 0; i < 3; i ++) { 65 for (int j = 0; j < n; j ++) 66 ntta[i][j] = A[j]; 67 for (int j = 0; j < n; j ++) 68 nttb[i][j] = B[j]; 69 NTT (ntta[i], 1, i); 70 NTT (nttb[i], 1, i); 71 for (int j = 0; j < n; j ++) 72 ntta[i][j] = ntta[i][j] * nttb[i][j] % MOD[i]; 73 NTT (ntta[i], - 1, i); 74 LL invn = power (n, MOD[i] - 2, MOD[i]); 75 for (int j = 0; j <= N + M; j ++) 76 ntta[i][j] = ntta[i][j] * invn % MOD[i]; 77 } 78 } 79 80 LL ans[MAXN]; 81 void CRT () { 82 LL m = MOD[0] * MOD[1]; 83 LL M1 = MOD[1], M2 = MOD[0]; 84 LL t1 = power (M1, MOD[0] - 2, MOD[0]), t2 = power (M2, MOD[1] - 2, MOD[1]), invM = power (m % MOD[2], MOD[2] - 2, MOD[2]); 85 for (int i = 0; i <= N + M; i ++) { 86 LL a1 = ntta[0][i], a2 = ntta[1][i], a3 = ntta[2][i]; 87 LL A = (multi (a1 * M1 % m, t1 % m, m) + multi (a2 * M2 % m, t2 % m, m)) % m; 88 LL k = ((a3 - A % MOD[2]) % MOD[2] + MOD[2]) % MOD[2] * invM % MOD[2]; 89 ans[i] = ((k % P * (m % P) % P + A % P) % P + P) % P; 90 } 91 } 92 93 int getnum () { 94 int num = 0; 95 char ch = getchar (); 96 97 while (! isdigit (ch)) 98 ch = getchar (); 99 while (isdigit (ch)) 100 num = (num << 3) + (num << 1) + ch - '0', ch = getchar (); 101 102 return num; 103 } 104 105 int main () { 106 N = getnum (), M = getnum (), P = (LL) getnum (); 107 for (int i = 0; i <= N; i ++) 108 A[i] = (LL) getnum (); 109 for (int i = 0; i <= M; i ++) 110 B[i] = (LL) getnum (); 111 112 NTT_Main (); 113 CRT (); 114 for (int i = 0; i <= N + M; i ++) { 115 if (i) 116 putchar (' '); 117 printf ("%lld", ans[i]); 118 } 119 puts (""); 120 121 return 0; 122 } 123 124 /* 125 5 8 28 126 19 32 0 182 99 95 127 77 54 15 3 98 66 21 20 38 128 */