多项式
这是优美的多项式家族
快速傅里叶变换(FFT)
问题:多项式乘法
原理先不写了,思想就是把系数表达转化为点值表达,点值运算之后再变回系数表达,复杂度\(O(nlogn)\)
点值选取的是负数域中的n次单位根
有时间会补上这块内容的
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
const int N = 4e6;
const double Pi = acos(-1.0);
using namespace std;
struct node
{
double x,y;
}a[N + 5],b[N + 5],w[N + 5];
int n,m,maxn,rev[N + 5],lg;
node operator +(node a,node b)
{
return (node){a.x + b.x,a.y + b.y};
}
node operator -(node a,node b)
{
return (node){a.x - b.x,a.y - b.y};
}
node operator *(node a,node b)
{
return (node){a.x * b.x - a.y * b.y,a.x * b.y + a.y * b.x};
}
void fft(node *a,int typ)
{
for (int i = 0;i < maxn;i++)
if (i < rev[i])
swap(a[i],a[rev[i]]);
for (int i = 1;i < maxn;i <<= 1)
for (int j = 0;j < maxn;j += i << 1)
for (int k = 0;k < i;k++)
{
node x = a[k + j],t = (node){w[i + k].x,w[i + k].y * typ} * a[k + j + i];
a[k + j] = x + t;
a[k + j + i] = x - t;
}
}
int main()
{
scanf("%d%d",&n,&m);
for (int i = 0;i <= n;i++)
scanf("%lf",&a[i].x);
for (int i = 0;i <= m;i++)
scanf("%lf",&b[i].x);
maxn = 1;
while (maxn <= m + n)
maxn <<= 1,lg++;
for (int i = 0;i <= maxn;i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
for (int i = 1;i < maxn;i <<= 1)
for (int j = 0;j < i;j++)
w[i + j] = (node){cos(Pi * j / i),sin(Pi * j / i)};
fft(a,1);
fft(b,1);
for (int i = 0;i < maxn;i++)
a[i] = a[i] * b[i];
fft(a,-1);
for (int i = 0;i <= n + m;i++)
printf("%d ",(int)(a[i].x / maxn + 0.1));
return 0;
}
快速数论变换(NTT)
就是把问题转化为了在模意义下,于是我们可以选择和单位根有类似性质的原根,时间复杂度仍是\(O(nlogn)\)
#include <iostream>
#include <cstdio>
#include <algorithm>
const int N = 5e6;
const int P = 998244353;
using namespace std;
int n,m,rev[N + 5],maxn,lg,a[N + 5],b[N + 5],g[N + 5][3];
int mypow(int a,int x)
{
int s = 1;
while (x)
{
if (x & 1)
s = 1ll * s * a % P;
a = 1ll * a * a % P;
x >>= 1;
}
return s;
}
void ntt(int *a,int typ)
{
for (int i = 0;i < maxn;i++)
if (i < rev[i])
swap(a[i],a[rev[i]]);
for (int i = 1;i < maxn;i <<= 1)
for (int j = 0;j < maxn;j += i << 1)
for (int k = 0;k < i;k++)
{
int x = a[k + j],t = 1ll * g[k + i][typ] * a[k + i + j] % P;
a[k + j] = (x + t) % P;
a[k + i + j] = ((x - t) % P + P) % P;
}
}
int main()
{
scanf("%d%d",&n,&m);
for (int i = 0;i <= n;i++)
scanf("%d",&a[i]);
for (int i = 0;i <= m;i++)
scanf("%d",&b[i]);
maxn = 1;
while (maxn <= n + m)
maxn <<= 1,lg++;
for (int i = 0;i <= maxn;i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
for (int i = 1;i < maxn;i <<= 1)
{
int G1 = mypow(3,(P - 1) / (i << 1)),G2 = mypow(mypow(3,P - 2),(P - 1) / (i << 1));
g[i][1] = 1;
g[i][0] = 1;
for (int j = 1;j < i;j++)
g[i + j][1] = 1ll * g[i + j - 1][1] * G1 % P,g[i + j][0] = 1ll * g[i + j - 1][0] * G2 % P;
}
ntt(a,1);
ntt(b,1);
for (int i = 0;i < maxn;i++)
a[i] = 1ll * a[i] * b[i] % P;
ntt(a,0);
int inv = mypow(maxn,P - 2);
for (int i = 0;i <= n + m;i++)
printf("%d ",1ll * a[i] * inv % P);
return 0;
}
多项式求逆
问题:给定一个多项式\(F(x)\),求一个多项式\(G(x)\),满足\(F(x)G(x)\equiv 1(mod\ x^n)\)
假设我们已经求出了一个\(F(x)\)在\(mod\ x^n\)下的逆\(G'(x)\),我们要求在\(mod\ x^{2n}\)下的逆\(G(x)\)
那么考虑
于是就可以愉快地递归求解了,时间复杂度\(T(n)=T(n/2)+O(nlogn)=O(nlogn)\)
Code
int INVa[N + 5];
void INV(int *a,int *ans,int n)
{
if (n == 1)
{
ans[0] = mypow(a[0],p - 2);
return;
}
INV(a,ans,n + 1 >> 1);
pre(n * 2);
for (int i = 0;i < n;i++)
INVa[i] = a[i];
clear(INVa,maxn,n);
ntt(INVa,1);
ntt(ans,1);
for (int i = 0;i < maxn;i++)
ans[i] = (2ll * ans[i] % p - 1ll * INVa[i] * ans[i] % p * ans[i] % p) % p;
ntt(ans,0);
clear(ans,maxn,n);
}
多项式对数函数(多项式 ln)
问题:给出 \(n-1\) 次多项式 \(A(x)\),求一个 \(\bmod{\:x^n}\) 下的多项式 \(B(x)\),满足 \(B(x) \equiv \ln A(x)\).
对两边同时求导\(B'(x)\equiv \frac{A'(x)}{A(x)}\)
积分回去\(B(x)\equiv \int \frac{A'(x)}{A(x)}dx\)
然后就是求导公式和积分公式
Code
int Lna[N + 5],Lnb[N + 5];
void DOV(int *a,int *f,int n)
{
for (int i = 1;i < n;i++)
f[i - 1] = 1ll * i * a[i] % p;
f[n - 1] = 0;
}
void DOVINV(int *a,int *f,int n)
{
f[0] = 0;
for (int i = 1;i < n;i++)
f[i] = 1ll * mypow(i,p - 2) * a[i - 1] % p;
}
void Ln(int *a,int *ans,int n)
{
DOV(a,Lna,n);
pre(n * 2);
clear(Lnb,maxn);
INV(a,Lnb,n);
pre(n * 2);
clear(Lna,maxn,n);
ntt(Lna,1);
ntt(Lnb,1);
for (int i = 0;i < maxn;i++)
Lna[i] = 1ll * Lna[i] * Lnb[i] % p;
ntt(Lna,0);
DOVINV(Lna,ans,n);
clear(ans,maxn,n);
}
多项式指数函数(多项式 exp)
问题:给出 \(n-1\) 次多项式 \(A(x)\),保证\(A_0=0\),求一个 \(\bmod{\:x^n}\) 下的多项式 \(B(x)\),满足 \(B(x) \equiv \text e^{A(x)}\)。
考虑用牛顿迭代解决这个问题
设\(F(B(x))=lnB(x)-A(x)\)
把\(A(x)\)看作常数项,所以\(F'(B(x))=\frac{1}{B(x)}\)
代入牛顿迭代的式子有
倍增求解即可
Code
int expa[N + 5],expb[N + 5];
void exp(int *a,int *ans,int n)
{
if (n == 1)
{
ans[0] = 1;
return;
}
exp(a,ans,n + 1 >> 1);
Ln(ans,expa,n);
pre(n * 2);
for (int i = 0;i < n;i++)
expb[i] = a[i];
clear(expb,maxn,n);
ntt(ans,1);
ntt(expa,1);
ntt(expb,1);
for (int i = 0;i < maxn;i++)
ans[i] = 1ll * ans[i] * ((1 - expa[i] + expb[i]) % p) % p;
ntt(ans,0);
clear(ans,maxn,n);
}
多项式快速幂
问题:给定一个 \(n-1\) 次多项式 \(A(x)\),求一个在 \(\bmod\ x^n\) 意义下的多项式 \(B(x)\),使得 \(B(x) \equiv A^k(x) \ (\bmod\ x^n)\)
我们对两边先ln再exp可以得到
于是\(k\)也可以取模了
然后注意到数据不一定保证\(A_0=1\),那么我们可以找到第一个非\(0\)的项\(a\),把\(A(x)\)的每一项都除以\(a\),变成\(\frac{A(x)}{a}\),并将后面的移到前面,这样就可以保证\(A_0=1\),最后再乘\(a^k\)并且处理\(0\)即可
Code
int pa[N + 5];
void mypow(int *a,int *ans,int n,int k)
{
Ln(a,pa,n);
for (int i = 0;i < n;i++)
pa[i] = 1ll * pa[i] * k % p;
exp(pa,ans,n);
}
多项式开根
问题:给定一个\(n-1\)次多项式\(A(x)\),求一个在\(\bmod\ x^n\)意义下的多项式\(B(x)\),使得\(B^2(x) \equiv A(x) \ (\bmod\ x^n)\)。若有多解,请取零次项系数较小的作为答案。
设\(H^2(x)\equiv F(x)(mod\ x^n)\)
那么考虑
倍增即可,只有一项的时候需要用二次剩余求根号
不过其实也可以先ln再exp回去
Code
int sqra[N + 5],sqrtmp[N + 5];
void sqr(int *a,int *ans,int n)
{
if (n == 1)
{
ans[0] = sq;
return;
}
sqr(a,ans,n + 1 >> 1);
pre(n * 2);
clear(sqra,maxn);
clear(sqrtmp,maxn);
INV(ans,sqra,n);
pre(n * 2);
for (int i = 0;i < n;i++)
sqrtmp[i] = a[i];
ntt(sqra,1);
ntt(sqrtmp,1);
ntt(ans,1);
int t = mypow(2,p - 2);
for (int i = 0;i < maxn;i++)
ans[i] = 1ll * ((sqrtmp[i] + 1ll * ans[i] * ans[i] % p) % p) * t % p * sqra[i] % p;
ntt(ans,0);
int inv = mypow(maxn,p - 2);
for (int i = 0;i < n;i++)
ans[i] = 1ll * ans[i] * inv % p;
clear(ans,maxn,n);
}
多项式除法
问题:给定一个\(n\)次多项式\(F(x)\)和一个\(m\)次多项式\(G(x)\),求出多项式\(Q(x),R(x)\)满足:
- \(Q(x)\)次数为\(n-m\),\(R(x)\)次数小于\(m\)
- \(F(x)=Q(x)G(x)+R(x)\)
首先设一个\(n\)项多项式\(A(x)\),假设一个\(r\)操作使得\(A_r(x)=x^nA(\frac{1}{x})\)
那么可以看出\(A_r[i]=A[n-i]\)
然后考虑下面的式子
于是我们对\(G_r(x)\)求逆,然后求得\(Q_r(x)\),再带回得到\(Q(x)\)
最后根据\(R(x)=F(x)-Q(x)G(x)\)求得\(R(x)\)
时间复杂度\(O(nlogn)\)
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
const int P = 998244353;
const int N = 1e6;
using namespace std;
int mypow(int a,int x)
{
int s = 1;
while (x)
{
if (x & 1)
s = 1ll * s * a % P;
a = 1ll * a * a % P;
x >>= 1;
}
return s;
}
int n,m,F[N + 5],G[N + 5],Q[N + 5],GR[N + 5],w[N + 5][3],maxn,lg,rev[N + 5],Gi[N + 5],c[N + 5],FR[N + 5];
void R(int *a,int *b,int n)
{
for (int i = 0;i <= n;i++)
b[i] = a[n - i];
}
void ntt(int *a,int typ)
{
for (int i = 0;i < maxn;i++)
if (i < rev[i])
swap(a[i],a[rev[i]]);
for (int i = 1;i < maxn;i <<= 1)
for (int j = 0;j < maxn;j += i << 1)
for (int k = 0;k < i;k++)
{
int x = a[j + k],t = 1ll * w[i + k][typ] * a[i + j + k] % P;
a[j + k] = (x + t) % P;
a[j + k + i] = ((x - t) % P + P) % P;
}
}
void ntt_pre(int n)
{
maxn = 1;
lg = 0;
while (maxn <= n)
maxn <<= 1,lg++;
for (int i = 0;i < maxn;i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
}
void INV(int n,int *a,int *b)
{
if (n == 1)
{
b[0] = mypow(a[0],P - 2);
return;
}
INV((n + 1) >> 1,a,b);
ntt_pre(n << 1);
for (int i = 0;i < n;i++)
c[i] = a[i];
for (int i = n;i < maxn;i++)
c[i] = 0;
ntt(c,1);
ntt(b,1);
for (int i = 0;i < maxn;i++)
b[i] = ((2ll * b[i] % P - 1ll * c[i] * b[i] % P * b[i] % P) % P + P) % P;
ntt(b,2);
int inv = mypow(maxn,P - 2);
for (int i = 0;i < n;i++)
b[i] = 1ll * b[i] * inv % P;
for (int i = n;i < maxn;i++)
b[i] = 0;
}
void NR(int *a,int *b,int n)
{
for (int i = 0;i <= n;i++)
b[n - i] = a[i];
}
int main()
{
scanf("%d%d",&n,&m);
for (int i = 0;i <= n;i++)
scanf("%d",&F[i]);
for (int i = 0;i <= m;i++)
scanf("%d",&G[i]);
maxn = 1;
while (maxn <= (n + m) * 2)
maxn <<= 1;
for (int i = 1;i < maxn;i <<= 1)
{
int G1 = mypow(3,(P - 1) / (i << 1)),G2 = mypow(mypow(3,P - 2),(P - 1) / (i << 1));
w[i][1] = w[i][2] = 1;
for (int j = 1;j < i;j++)
w[i + j][1] = 1ll * w[i + j - 1][1] * G1 % P,w[i + j][2] = 1ll * w[i + j - 1][2] * G2 % P;
}
R(G,GR,m);
INV(n - m + 2,GR,Gi);
R(F,FR,n);
ntt_pre(n * 2 - m + 2);
ntt(FR,1);
ntt(Gi,1);
for (int i = 0;i < maxn;i++)
Gi[i] = 1ll * Gi[i] * FR[i] % P;
ntt(Gi,2);
int inv = mypow(maxn,P - 2);
for (int i = 0;i < maxn;i++)
Gi[i] = 1ll * Gi[i] * inv % P;
NR(Gi,Q,n - m);
for (int i = 0;i <= n - m;i++)
printf("%d ",Q[i]);
cout<<endl;
for (int i = n - m + 1;i < maxn;i++)
Q[i] = 0;
ntt_pre(n + m);
ntt(Q,1);
ntt(G,1);
ntt(F,1);
for (int i = 0;i < maxn;i++)
F[i] = ((F[i] - 1ll * Q[i] * G[i] % P) % P + P) % P;
ntt(F,2);
inv = mypow(maxn,P - 2);
for (int i = 0;i < m;i++)
printf("%d ",1ll * F[i] * inv % P);
return 0;
}