NTT学习笔记
NTT学习笔记
前言
FFT
- 我们知道\(FFT\)可以快速的完成两个多项式的乘法,利用了单位复根的特殊性质。
- 由于复数的实部与虚部是正余弦函数,需要做浮点数运算,以及产生误差。
- 这样计算量比较大,而且复数不可以取模。
NTT
-
中文名:快速数论变换。
-
多项式乘法有时候会建立在模域,对一些特殊的大质数取模时,可以考虑用原根\(g\)来代替,而这些特殊的大质数的原根恰好满足了某些性质,使得多项式乘法在模域中也可以快速的分治合并。
前置知识
阶
- 若\(a,p\)互质,且\(p>1\)。
- 对于\(a^n\equiv 1(mod\ p)\)最小的\(n\),我们成为\(a\)模\(p\)的阶,记做\(\delta_p(a)\)。
- 例如:\(\delta_7(2)=3\)。
原根
-
设\(p\)是正整数,\(a\)是整数,若\(\delta_p(a)\)等于\(\varphi(p)\),则称\(a\)为模\(p\)的一个原根。
-
比如说\(\delta_7(3)=6=\varphi(7)\),因此\(3\)是模\(7\)的一个原根。
-
重要定理:(其实只要知道这个就行了)
-
对于\(g,p\in Z\),如果\(g^i\ mod\ p(1\leq i\leq p-1)\)的值互不相同,则称\(g\)是\(p\)的原根。
-
常见的模数有\(998244353,1004535809,469762049\),这几个数的原根都是\(3(g=3)\)。
NTT
- \(FFT\)能够大大优化多项式乘法是因为单位复根有特殊且优秀的性质。
- 原根也有。
- 在\(NTT\)中,用原根来代替\(FFT\)中的单位复根。
- 任意模数\(NTT\)以后再说。
洛谷3803:多项式乘法
代码和FFT挺像的。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
template<typename T>inline void read(T &x){
x=0;
static int p;p=1;
static char c;c=getchar();
while(!isdigit(c)){if(c=='-')p=-1;c=getchar();}
while(isdigit(c)) {x=(x<<1)+(x<<3)+(c-48);c=getchar();}
x*=p;
}
const int maxn = 5e6 + 10;
const int mod = 998244353;
int n, m, a[maxn], b[maxn], limit=1, bit;
int rev[maxn];
ll qmi(ll a, ll b)
{
ll res = 1; res %= mod;
while(b)
{
if(b&1) res = (res*a) % mod;
b >>= 1;
a = (a*a)%mod;
} return res%mod;
}
void NTT(int c[], int op)
{
for(int i = 0; i < limit; i++)
if(i < rev[i]) swap(c[i], c[rev[i]]);
for(int mid = 1; mid < limit; mid <<= 1)
{
ll gn = qmi(3, (mod-1)/(mid<<1));
if(op == -1) gn = qmi(gn, mod-2);
for(int j = 0, R = mid<<1; j < limit; j += R)
{
ll g = 1;
for(int k = 0; k < mid; k++, g = (g*gn)%mod)
{
int x = c[j+k], y = g*c[j+k+mid]%mod;
c[j+k] = (x+y)%mod;
c[j+k+mid] = (x-y+mod)%mod;
}
}
}
}
int main()
{
read(n), read(m);
for(int i = 0; i <= n; i++) read(a[i]);
for(int i = 0; i <= m; i++) read(b[i]);
limit = 1;
while(limit <= n+m) limit <<= 1, bit++;
for(int i = 0; i < limit; i++)
rev[i] = (rev[i>>1]>>1)|((i&1)<<(bit-1));
NTT(a, 1); NTT(b, 1);
for(int i = 0; i < limit; i++) a[i] = 1ll*a[i]*b[i]%mod;
NTT(a, -1);
ll inv = qmi(limit, mod-2);
for(int i = 0; i <= n+m; i++)
printf("%d ", (a[i]*inv)%mod);
return 0;
}