多项式模板
// 洛谷多项式模板题,整理了自己的数论FFT,NTT的模板。
P3803 【模板】多项式乘法(FFT)
#include<cstdio> #include<iostream> #include<complex> #include<cmath> const double pi = acos(-1); // #define cp complex<double> using namespace std; const int maxn = 1500100; struct cp { double x, y; cp(double xx=0, double yy=0):x(xx), y(yy) {} cp operator+(const cp& b) { return cp(x+b.x, y+b.y); } cp operator-(const cp& b) { return cp(x-b.x, y-b.y); } cp operator*(const cp& b) { return cp(x*b.x-y*b.y, x*b.y+y*b.x); } cp operator*=(const cp& b) { double tmpx = x*b.x - y*b.y; y = x*b.y + y*b.x; x = tmpx; return *this; } }; int n, m; cp a[maxn*4], b[maxn*4]; // 一定要4倍内存!!! int rev[maxn*4]; void FFT(cp a[], int n, int inv) { // if(!n) return; for(int i=0;i<n;i++) { // rev[i] = (rev[i>>1]>>1)|((i&1)<<(bit-1)); if(i<rev[i]) //不加这条if会交换两次(就是没交换) swap(a[i], a[rev[i]]); } for(int mid=1;mid<n;mid*=2) { //mid是准备合并序列的长度的二分之一 cp temp(cos(pi/mid), inv*sin(pi/mid)); for(int i=0;i<n;i+=mid*2) { //mid*2是准备合并序列的长度,i是合并到了哪一位 cp w(1, 0); for(int j=0;j<mid;j++,w*=temp) {//只扫左半部分,得到右半部分的答案 cp Nx = a[i+j]; cp Ny = w*a[i+j+mid]; a[i+j] = Nx + Ny; a[i+j+mid] = Nx - Ny; } } } } int main() { scanf("%d %d", &n, &m); for(int i=0;i<=n;i++) { scanf("%lf", &a[i].x); } for(int i=0;i<=m;i++) { scanf("%lf", &b[i].x); } int len = 1, bit = 0; while(len<=n+m) len <<= 1, ++bit; for(int i=0;i<len;i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<(bit-1)); FFT(a, len, 1); FFT(b, len, 1); for(int i=0;i<=len;i++) a[i] *= b[i]; FFT(a, len, -1); for(int i=0;i<=n+m;i++) printf("%d ", (int)(a[i].x/len + 0.5)); return 0; }
P1919 【模板】A*B Problem升级版
利用上面的FFT模板,可以用来做大数乘法。
朴素的模拟手算乘法时间复杂度O(n^2),n为A,B的位数。改用FFT时间复杂度将优化到 O(nlogn)。
#include<cstdio> #include<iostream> #include<cmath> const double pi = acos(-1); using namespace std; const int maxn = 60100; int n; struct cp { // 复数类定义 double x, y; cp(double xx=0, double yy=0):x(xx), y(yy) {} cp operator+(const cp& b) { return cp(x+b.x, y+b.y); } cp operator-(const cp& b) { return cp(x-b.x, y-b.y); } cp operator*(const cp& b) { return cp(x*b.x-y*b.y, x*b.y+y*b.x); } cp operator*=(const cp& b) { double tmpx = x*b.x - y*b.y; y = x*b.y + y*b.x; x = tmpx; return *this; } }; // FFT模板 int rev[maxn*4]; void FFT(cp a[], int n, int inv) { // if(!n) return; for(int i=0;i<n;i++) { if(i<rev[i]) swap(a[i], a[rev[i]]); } for(int mid=1;mid<n;mid*=2) { cp temp(cos(pi/mid), inv*sin(pi/mid)); for(int i=0;i<n;i+=mid*2) { cp w(1, 0); for(int j=0;j<mid;j++,w*=temp) { cp Nx = a[i+j]; cp Ny = w*a[i+j+mid]; a[i+j] = Nx + Ny; a[i+j+mid] = Nx - Ny; } } } } cp a[maxn*4], b[maxn*4]; // 一定要4倍内存!!! int ans[maxn*4]; int main() { scanf("%d", &n); for(int i=n-1;i>=0;i--) { scanf("%1lf", &a[i].x); } for(int i=n-1;i>=0;i--) { scanf("%1lf", &b[i].x); } int len = 1, bit = 0; n <<= 1; while(len<=n) len <<= 1, ++bit; for(int i=0;i<len;i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<(bit-1)); FFT(a, len, 1); FFT(b, len, 1); for(int i=0;i<len;i++) a[i] *= b[i]; FFT(a, len, -1); int carry = 0; for(int i=0;i<len;i++) { int now = (int)(a[i].x/len + 0.5); ans[i] = (carry + now) % 10; carry = (carry + now) / 10; } int s = len-1; while(s>=0 && ans[s]==0) --s; for(int i=s;i>=0;i--) printf("%d", ans[i]); puts(""); return 0; }
P4238 【模板】多项式求逆
题意:
给一个多项式 a[x] = ∑ ai*x^i ,求 F[x] 满足 a[x] * F[x] = 1 (mod x^n) 。
求解原理:
假设我们找到 a[x] * G[x] = 1 (mod x^(n/2)) 时的解 G[x]
又由于 a[x] * F[x] = 1
所以 a[x] * ( F[x] - G[x] ) = 0 (mod x^(n/2))
则 F[x] - G[x] = 0 (mod x^(n/2))
平方后,得
F[x] * F[x] + G[x] * G[x] - 2 * F[x] * G[x] = 0 (mod x^n)
再乘上 a[x]
F[x] + a[x] * G[x] * G[x] - 2 * G[x] = 0 (mod x^n)
所以
F[x] = 2 * G[x] - a[x] * G[x] * G[x]
我们找到了递推式,当 n = 1 时,G[0] 为 a[0] 的逆元。
模板适用条件:
模数 mod = 998244353 等原根为 3 的素数。
(代码模板见下一题。)
P4725 【模板】多项式对数函数(多项式 ln)
原理:
ln(A[x]) ==> B[x]
两边求导,得
A'[x] / A[x] = B'[x]
积分,得
B[x] = ∫ ( A'[x] / A[x] )
其中 1 / A[x] 可以用前面的多项式求逆模板。
所以
B[x] = Interx( A'[x] / A[x] ) = Interx( Dx(A[x]) * Inv(A[x]) )
这两题均使用了NTT和多项式求逆,下面的代码为P4725,P4238将main函数里的调用POLYMONY::Inv即可。
代码:
#include<cstdio> #include<iostream> #include<cmath> #include<cstring> using namespace std; typedef long long ll; const int maxn = 400100; // 多项式相乘时需要开4倍大小 namespace POLYNOME { const int mod = 998244353; const int g = 3; // 原根: 3 为 mod 的原根 int rev[maxn]; int myPow(ll a, int n) { ll res = 1; while(n) { if(n&1) res = res*a % mod; a = a*a % mod; n >>= 1; } return res; } // NTT模板 // inv==1 快速数论变换 // inv==-1 逆变换 注意最后点值除以长度为系数结果!!! void NTT(ll a[], int n, int inv) { // if(!n) return; for(int i=0;i<n;i++) { // rev[i] = (rev[i>>1]>>1)|((i&1)<<(bit-1)); if(i<rev[i]) swap(a[i], a[rev[i]]); } for(int mid=1;mid<n;mid*=2) { ll tmp = myPow(g, (mod-1)/(mid*2)); if(inv==-1) tmp = myPow(tmp, mod-2); for(int i=0;i<n;i+=mid*2) { ll w = 1; for(int j=0;j<mid;j++,w=w*tmp%mod) { int x = a[i+j]; int y = w * a[i+j+mid] % mod; a[i+j] = (x + y) % mod; a[i+j+mid] = ((x - y)%mod + mod) % mod; } } } // if(inv==-1) { // for(int i=0;i<n;i++) // a[i]/n // a[i] = a[i]* myPow(n, mod-2) % mod; // } } // 多项式求逆 // Inv(a[x]) = 1/a[x] => G[x] ll F[maxn]; void Inv(ll a[], ll G[], int n) { if(n==1) { G[0] = myPow(a[0], mod-2); return; } Inv(a, G, n+1>>1); ll lim = 1, bit = 0; while(lim<(n<<1)) lim<<=1, ++bit; for(int i=1;i<lim;i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<(bit-1)); memset(F, 0, sizeof(0)); for(int i=0;i<n;i++) F[i] = a[i]; NTT(G, lim, 1); NTT(F, lim, 1); for(int i=0;i<lim;i++) G[i] = (2- F[i]*G[i]%mod + mod)%mod * G[i] % mod; NTT(G, lim, -1); for(int i=0;i<lim;i++) // G[i]/lim G[i] = G[i]* myPow(lim, mod-2) % mod; for(int i=n;i<lim;i++) G[i] = 0; } // 多项式求导 // Dx(A[x]) => B[x] ll Da[maxn]; void Dx(ll A[], ll B[], int n) { for(int i=1;i<n;i++) B[i-1] = A[i]*i % mod; B[n-1] = 0; } // 多项式积分 // Inter(A[x]) => B[x] ll Sa[maxn]; void Interx(ll A[], ll B[], int n) { for(int i=1;i<n;i++) B[i] = A[i-1]*myPow(i, mod-2) % mod; B[0] = 0; } // 多项式求对数 // ln(A[x]) => B[x] // B[x] = Interx(A'(x)/A(x)) = Interx(Dx(A[x])*Inv(A[x])) // Da * Sa void lnx(ll a[], ll res[], int n) { Inv(a, Sa, n); Dx(a, Da, n); int lim = 1; int bit = 0; while(lim<(n<<1)) lim<<=1, ++bit; for(int i=0;i<lim;i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<(bit-1)); NTT(Da, lim, 1); NTT(Sa, lim, 1); for(int i=0;i<lim;i++) Da[i] = Da[i] * Sa[i] % mod; NTT(Da, lim, -1); for(int i=0;i<lim;i++) // Da[i]/lim Da[i] = Da[i]* myPow(lim, mod-2) % mod; Interx(Da, res, n); } } int n; ll a[maxn], ans[maxn]; int main() { scanf("%d", &n); for(int i=0;i<n;i++) { scanf("%lld", &a[i]); } POLYNOME::lnx(a, ans, n); for(int i=0;i<n;i++) printf("%lld ", ans[i]); return 0; }