浅谈多项式全家桶与实现
upd:多项式幂函数(加强版)中将 log(n)
改为 5
2021/5/8 upd:修补了一些锅,增加了大量技术,多项式全家桶基本完成
广告
现已加入:
-
分治 NTT -
分治 + NTT
写在前面
众所周知,写多项式最麻烦的事情就是封装、清空和边界
所以务必考虑好每一步细节,算好边界,记得清空
如果你遇到以下玄学错误:
-
RE
了?多半是因为没开大栈空间(
请在编译选项里加入
--std=c++11 -Wl,--stack=1145149191810
不过这个问题最好的解决办法是在函数内开数组时使用
static
-
WA
了?多半是因为没清空
-
TLE
了?多半是因为……板子常数太大了……
尽量别用
memset
,小心T飞
我的习惯:
NTT
和 INV
比较常用,单独开在两个 namespace
里
还有求值和插值,分别开在两个 namespace
里
下降幂多项式专门开在 FallingPoly
里
其它的函数都开在 poly
里
函数形如 void (long long *s,long long *f,lnog long n)
s
表示的是输出数组(为了方便封装,s
和f
可以是同一个数组)f
表示的是输入数组n
表示的是多项式次数(即数组下标为 \(0-n\))
大写字母 A
B
C
等数组表示的是临时数组
最后一些约定:
-
本文中题目及推导中的 \(n\) 均是多项式次数加一(即与代码中不一样)
-
本文中非特殊说明,模数均为 \(998244353\) (
-
本文中输入多项式次数均少于 \(10^5\) ,因此数组大小
N
一般开到 \(3 \times 10^5\) (因为最大可以到 \(262144\)不过由于未知原因多项式开根要开到 \(6 \times 10^5\)
0. 前置知识
背~下~来
-
多项式
多项式是啥?
多项式就是一堆单项式的和(好废话,不唠叨了
-
泰勒展开
没学过所以不敢乱讲
可以看看这篇劲爆的介绍 [怎样更好地理解并记忆泰勒展开式?——知乎](怎样更好地理解并记忆泰勒展开式? - 知乎 (zhihu.com))
-
牛顿迭代法
牛顿迭代实际上是快速求函数零点的东西
大概就是,随便选出一个点 \((x_0,f(x_0))\),作出过它的切线 \(y=f'(x_0)(x-x_0)+f(x_0)\)
如果求出切线的零点 \((x_0-\dfrac{f(x_0)}{f'(x_0)},0)\) ,就会发现,它离真正的零点更近了!
这样一直迭代下去,就会收敛于零点
换成多项式,比如想求使得 \(F(G(x)) \equiv 0 \pmod x^n\) 的 \(G(x)\)
假设已经求出了 \(F(G_0(x)) \equiv 0 \pmod{x^\frac{n}{2}}\)
求出 \(G(x) = G_0(x) - \dfrac{F(G_0(x))}{F'(G_0(x))}\)
有 \(F(G(X)) \equiv 0 \pmod{x^n}\)
递归求下去,时间复杂度是 \(O(T)=O(\dfrac{T}{2})+O(T \log T)=O(n \log n)\)
我也不知道为什么,背下来
1. FFT/NTT
多项式乘法,或者叫多项式卷积
最基础的多项式操作,其它运算都以此为基础
原理较为复杂,本文不多加解释
\(O(n \log n)\)
pre()
用于预处理 \(\omega_l\) 以加速,记得先在主函数内调用一次
long long pr[N];
namespace NTT{
static long long A[N],B[N],rev[N];
void pre(){
for(long long l=1;l<N;l<<=1ll)
pr[l]=ksm(g,(mod-1)/(l*2));
}
long long init(long long n,long long m){
long long lim=0;
while((1ll<<lim)<=n+m) lim++;
for(long long i=0;i<=(1<<lim)-1;i++)
rev[i]=(rev[i>>1ll]>>1ll) | ((i & 1ll)<<(lim-1));
return lim;
}
void ntt(long long *f,long long n,long long opt){
for(long long i=0;i<=n-1;i++)
if(i<rev[i]) swap(f[i],f[rev[i]]);
for(long long l=1;l<n;l<<=1ll){
long long tmp=pr[l];
if(opt==-1) tmp=ksm(tmp,mod-2);
for(long long i=0;i<=n-1;i+=l<<1ll){
long long omegf=1;
for(long long j=0;j<l;j++){
long long x=f[i+j],y=omegf*f[i+j+l]%mod;
f[i+j]=(x+y)%mod,f[i+j+l]=(x-y+mod)%mod;
omegf=omegf*tmp%mod;
}
}
}
if(opt==-1){
long long t=ksm(n,mod-2);
for(long long i=0;i<=n-1;i++)
f[i]=f[i]*t%mod;
}
}
void solve(long long *s,long long* f,long long* g,long long n,long long m){
long long lim=init(n,m);
for(long long i=0;i<=n;i++) A[i]=f[i];
for(long long i=0;i<=m;i++) B[i]=g[i];
//不要破坏f和g,先拿A和B寄存
for(long long i=n+1;i<=(1ll<<lim);i++) A[i]=0;
for(long long i=m+1;i<=(1ll<<lim);i++) B[i]=0;
//清空
ntt(A,(1<<lim),1);
ntt(B,(1<<lim),1);
for(long long i=0;i<=(1<<lim)-1;i++) s[i]=A[i]*B[i]%mod;
ntt(s,(1<<lim),-1);
}
}
2. 多项式求逆
-
给出 \(F(x)\) ,求出 \(G(x)\) 使得 \(F(x)G(x) \equiv 1 \pmod {x^n}\)
记 \(H(t)=F(x)- \dfrac{1}{t}\) ,则有 \(H(G(x)) \equiv 0 \pmod{x^n}\)
- 这里 \(F(x)\) 看作常数!
就是找 \(H(G(x))\) 的零点了,上牛顿迭代
假设已经求出了 \(H(G_0(x)) \equiv 0 \pmod{x^\frac{n}{2}}\)
递归边界 \(G_0(x) = \dfrac{1}{f_0}\)
\(O(n \log n)\)
- 计算时每次由 \(\frac{len}{2}\) 转移到 \(len\) ,注意这里的 \(len\) 实际上是最高次数加一,即实际上次数为 \(len-1\)
- 虽然每次由 \(\frac{len}{2}\) 转移到 \(len\) ,但在计算时可能会到 \(lim=len \times 2\) ,记得清空
namespace INV{
long long A[N],B[N],S[N];
void solve(long long *s,long long *f,long long n){
S[0]=ksm(f[0],mod-2);
S[1]=0;
long long len;
for(len=2;len<=(n<<1ll);len<<=1ll){
// ^^ 这里一定要写小于等于!
//len是现在处理的长度(即x^n)
long long lim=len<<1ll;
for(long long i=0;i<len;i++) A[i]=f[i],B[i]=S[i];
for(long long i=len;i<lim;i++) A[i]=B[i]=0;
NTT::init(len,0);
NTT::ntt(A,lim,1);
NTT::ntt(B,lim,1);
for(long long j=0;j<lim;j++)
S[j]=(2*B[j]+mod-A[j]*B[j]%mod*B[j]%mod)%mod;
NTT::ntt(S,lim,-1);
for(long long j=len;j<lim;j++) S[j]=0;
}
for(long long i=0;i<=n;i++) s[i]=S[i];
}
}
3. 多项式对数函数(ln)
-
给出 \(F(x)\) ,求 \(G(x)=ln(F(x)) \pmod{x^n}\) 。
-
保证 \(f_0=1\)
众所周知 \(ln\) 求导有很好的性质: \(ln(x)'=\frac{1}{x}\)
别忘了链式法则
求导,求逆,再积分即可
\(O(n \log n)\)
求导:
\((x^\alpha)'=\alpha x^{\alpha-1}\)
void Deriv(long long *s,long long *f,long long n){
A[n]=0;
for(long long i=1;i<=n;i++) A[i-1]=f[i]*i%mod;
for(long long i=0;i<=n;i++) s[i]=A[i];
}
积分:
void Limit(long long *s,long long *f,long long n){
A[0]=0;
for(long long i=0;i<=n-1;i++) A[i+1]=f[i]*ksm(i+1,mod-2)%mod;
for(long long i=0;i<=n;i++) s[i]=A[i];
}
void Ln(long long *s,long long *f,long long n){
static long long A[N],B[N];
int lim=NTT::init(n,n);
for(long long i=0;i<=n;i++) A[i]=f[i],B[i]=0;
for(int i=n+1;i<(1ll<<lim);i++) A[i]=B[i]=0;
Deriv(B,A,n);
INV::solve(A,A,n);
NTT::solve(A,A,B,n,n);
Limit(A,A,n);
for(long long i=0;i<=n;i++) s[i]=A[i];
}
4. 多项式指数函数(exp)
-
给出 \(F(x)\) ,求出 \(G(x)\) 使得 \(G(x) \equiv e^{F(x)} \pmod {x^{n}}\)
-
保证 \(f_0=0\)
记 \(H(t)=F(x)- ln(t)\) ,则有 \(H(G(x)) \equiv 0 \pmod{x^n}\)
再上牛顿迭代
假设已经求出了 \(H(G_0(x)) \equiv 0 \pmod{x^\frac{n}{2}}\)
\(O(n \log n)\)
void Exp(long long *s,long long *f,long long n){
static long long A[N],B[N],C[N],S[N];
S[0]=1;
S[1]=0;
long long len;
for(len=2;len<=(n<<1ll);len<<=1ll){
long long lim=len<<1ll;
for(long long i=0;i<len;i++) A[i]=f[i],B[i]=S[i];
for(long long i=len;i<lim;i++) A[i]=B[i]=C[i]=0;
Ln(C,B,len-1);
for(long long i=0;i<len;i++) C[i]=(mod-C[i]+A[i])%mod;
C[0]=(C[0]+1)%mod;
NTT::init(len,0);
NTT::ntt(B,lim,1);
NTT::ntt(C,lim,1);
for(long long j=0;j<lim;j++) S[j]=B[j]*C[j]%mod;
NTT::ntt(S,lim,-1);
for(long long j=len;j<lim;j++) S[j]=0;
}
for(long long i=0;i<=n;i++) s[i]=S[i];
}
5. 多项式除法(取模)
-
给出 \(n\) 次多项式 \(F(x)\) 和 \(m\) 次多项式 \(G(x)\) ,求 \(n-m\) 次多项式 \(Q(x)\) 和不超过 \(m-1\) 次多项式 \(R(x)\) 使得 \(F(x)=Q(x)G(x)+R(x)\) 。
这次有两个多项式,牛顿迭代又用不了了
上一个黑科技:将系数对称
再试试:
考虑先求出 \(Q(x)\) ,模上 \(x^{n-m+1}\)
求出 \(Q(x)\) 后 \(R(x)\) 也很简单了
void Rev(long long *s,long long *f,long long n){
for(long long i=0;i<=n/2;i++) swap(f[i],f[n-i]);
}
void Div(long long *q,long long* r,long long *f,long long *g,long long n,long long m){
static long long A[N],B[N],C[N],D[N];
for(long long i=0;i<=n;i++) A[i]=C[i]=f[i];
for(long long i=0;i<=m;i++) B[i]=D[i]=g[i];
Rev(A,A,n);
Rev(B,B,m);
INV::solve(B,B,n-m);
NTT::solve(q,A,B,n,n-m);
for(long long i=n-m+1;i<=n+n-m;i++) q[i]=0;
Rev(q,q,n-m);
NTT::solve(D,D,q,m,n-m);
Del(r,C,D,n,n);
}
6. 多项式快速幂
-
给出多项式 \(F(x)\) ,求 \(F(x)^k \pmod{x^n}\)
-
保证 \(f_0=1\)
考虑取 \(\ln\)
这里的 \(k\) 可以对 \(\bmod\) 取模
\(O(n\log n)\)
void Pow(long long *s,long long *f,long long n,long long k){
int lim=NTT::init(n,n);
for(int i=n+1;i<(1ll<<lim);i++) A[i]=0;
for(long long i=0;i<=n;i++) A[i]=f[i];
Ln(A,A,n);
Mult(A,A,n,k);
Exp(A,A,n);
for(long long i=0;i<=n;i++) s[i]=A[i];
}
- 加强版:不保证 \(f_0=1\)
\(\ln\) 必须要保证 \(f_0=1\) ,咋办?
只能将 \(f_0\) 转化成 \(1\)
如果 \(f_0\) 不为 \(0\) 的话,直接除掉就行了
如果为 \(0\) 呢?
那就找到第一个不为 \(0\) 的系数(假设为第 \(t\) 项),然后除掉 \(x^t\) 即可
列一下式子:
记 \(p=[x^t]f(x)\)
还有点实现的小问题:当 \(tk>n\) 时就不用算了,因为全是 \(0\)
void Pow(long long *s,long long *f,long long n,long long k1,long long k2,long long lenk){
//k1 是对 mod 取模, k2 是对 mod-1 取模,lenk 是 k 的长度
static long long A[N];
long long lim=NTT::init(n,n);
long long t=0;
while(!f[t]) t++;
if(lenk>log(n) && t){
for(long long i=0;i<=n;i++) s[i]=0;
return ;
}
long long p=ksm(f[t],k2),q=ksm(f[t],mod-2);
n-=t;
for(long long i=0;i<=n;i++) A[i]=f[i+t]*q%mod;
for(long long i=n+1;i<(1ll<<lim);i++) A[i]=0;
Ln(A,A,n);
Mult(A,A,n,k1);
Exp(A,A,n);
for(long long i=0;i<=n+t;i++) s[i]=0;
for(long long i=0;i+t*k1<=n+t;i++)
s[i+t*k1]=A[i]*p%mod;
}
7. 多项式开根
-
给出多项式 \(F(x)\) ,求 \(G(x)\) 使 \(G(x)^2 \equiv F(x) \pmod{x^{n}}\)
-
保证 \(f_0=1\)
记 \(H(t)=F(x)- t^2\) ,则有 \(H(G(x)) \equiv 0 \pmod{x^n}\)
又上牛顿迭代
假设已经求出了 \(H(G_0(x)) \equiv 0 \pmod{x^\frac{n}{2}}\)
\(O(n \log n)\)
- 事实上开根可以得到两个答案(差别就是在边界条件),模板题中要求系数字典序最小,令 \(G_0(x)=1\) 即可
void Sqrt(long long *s,long long *f,long long n){
static long long A[N],B[N],C[N],S[N];
S[0]=1;
long long len;
for(len=1;len<=(n<<1ll);len<<=1ll){
long long lim=len<<1ll;
for(long long i=0;i<len;i++) A[i]=f[i],B[i]=S[i];
for(long long i=len;i<lim;i++) A[i]=B[i]=C[i]=0;
INV::solve(C,B,len);
NTT::init(len,0);
NTT::ntt(A,lim,1);
NTT::ntt(C,lim,1);
for(long long j=0;j<lim;j++) S[j]=A[j]*C[j]%mod;
NTT::ntt(S,lim,-1);
for(long long i=0;i<lim;i++) S[i]=(S[i]+B[i])%mod*inv2%mod;
for(long long j=len;j<lim;j++) S[j]=0;
}
for(long long i=0;i<=n;i++) s[i]=S[i];
}
- 不保证 \(F(0)=1\) ,但保证有解
其它都没啥问题,就是边界条件
奆佬您一定想到了二次剩余
这里我选择用 \(BSGS\) 求出 \(g^x \equiv F(0) \pmod{998244353}\)
然后 \(G_0(x)=g^{\frac{x}{2}}\)
void Sqrt(long long *s,long long *f,long long n){
static long long A[N],B[N],C[N],S[N];
long long p=BSGS(3,998244353,f[0]);
S[0]=ksm(3,p/2);
S[1]=0;
long long len;
for(len=1;len<=(n<<1ll);len<<=1ll){
long long lim=len<<1ll;
for(long long i=0;i<len;i++) A[i]=f[i],B[i]=S[i];
for(long long i=len;i<lim;i++) A[i]=B[i]=C[i]=0;
INV::solve(C,B,len);
NTT::init(len,0);
NTT::ntt(A,lim,1);
NTT::ntt(C,lim,1);
for(long long j=0;j<lim;j++) S[j]=A[j]*C[j]%mod;
NTT::ntt(S,lim,-1);
for(long long i=0;i<lim;i++) S[i]=(S[i]+B[i])%mod*inv2%mod;
for(long long j=len;j<lim;j++) S[j]=0;
}
for(long long i=0;i<=n;i++) s[i]=S[i];
}
8. 多项式三角函数
-
给出多项式 \(F(x)\) ,求 \(sin(x)\),\(cos(x)\) 与 \(tan(x)\)。
-
保证 \(f_0=0\)
总所周知复数的指数形式(欧拉公式):
顺便一提,\(i=\omega_4\)
\(O(n \log n)\)
void Sin(long long *s,long long *f,long long n){
long long I=mod-ksm(g,(mod-1)/4);
static long long A[N],B[N];
for(long long i=0;i<=n;i++) A[i]=f[i]*I%mod;
poly::Exp(A,A,n);
INV::solve(B,A,n);
long long t=ksm(2*I%mod,mod-2);
for(long long i=0;i<=n;i++) s[i]=(A[i]+mod-B[i])%mod*t%mod;
}
void Cos(long long *s,long long *f,long long n){
long long I=mod-ksm(g,(mod-1)/4);
static long long A[N],B[N];
for(long long i=0;i<=n;i++) A[i]=f[i]*I%mod;
poly::Exp(A,A,n);
INV::solve(B,A,n);
long long t=inv2;
for(long long i=0;i<=n;i++) s[i]=(A[i]+B[i])%mod*t%mod;
}
9. 多项式反三角函数
-
给出多项式 \(F(x)\) ,求 \(\arcsin(x)\),\(\arccos(x)\) 与 \(\arctan(x)\)。
-
保证 \(f_0=0\)
考虑将它转换出三角函数的圈子,像 ln
一样求导
于是:
由此看得出来 arccos
和 arccot
不用写了
\(O(n \log n)\)
void Arcsin(long long *s,long long *f,long long n){
static long long A[N],B[N];
for(long long i=0;i<=n;i++) A[i]=f[i],B[i]=f[i];
NTT::solve(A,A,A,n,n);
for(int i=0;i<=n;i++) A[i]=(mod-A[i])%mod;
A[0]=(A[0]+1)%mod;
poly::Sqrt(A,A,n);
INV::solve(A,A,n);
poly::Deriv(B,B,n);
NTT::solve(A,A,B,n,n);
poly::Limit(A,A,n);
for(long long i=0;i<=n;i++) s[i]=A[i];
}
void Arctan(long long *s,long long *f,long long n){
static long long A[N],B[N];
for(long long i=0;i<=n;i++) A[i]=f[i],B[i]=f[i];
NTT::solve(A,A,A,n,n);
A[0]=(A[0]+1)%mod;
INV::solve(A,A,n);
poly::Deriv(B,B,n);
NTT::solve(A,A,B,n,n);
poly::Limit(A,A,n);
for(long long i=0;i<=n;i++) s[i]=A[i];
}
10. 多项式多点求值
-
给出多项式 \(F(x)\) ,求给定 \(m\) 处的点值。
有一个很神奇的东西:
若 \(F(x)=Q(x)(x-x_0)+R(x)\),则 \(F(x_0)=R(x_0)\)
所以我们只要对每个 \(x_i\) 求出取模后的余式 \(R(x)\) 即可
考虑分治,先 分治 + NTT 处理出 \(\prod\limits_{i=l}^r (x-x_i)\)
当前获得了 \(F(x)=Q(x) \prod\limits_{i=l}^r (x-x_i) +R(x)\)
那么分治下去只要对两边分别取模就可以了,显然次数会减半
\(O(n\log^2n)\)
namespace Evaluation{
long long tmp[N];
vector <long long> P[N],Q[N];
long long lenp[N],lenq[N];
void init(long long o,long long l,long long r){
long long len=r-l+1,mid=(l+r)>>1ll;
lenp[o]=len;
lenq[o]=len-1;
P[o].resize(2*lenp[o]+2);
Q[o].resize(2*lenq[o]+2);
if(l==r) return ;
init(o<<1,l,mid);
init(o<<1|1,mid+1,r);
}
void solve1(long long *a,long long o,long long l,long long r){
if(l==r){
P[o]={(mod-a[l])%mod,1};
return ;
}
long long mid=(l+r)>>1;
solve1(a,o<<1,l,mid);
solve1(a,o<<1|1,mid+1,r);
NTT::solve(P[o].data(),P[o<<1].data(),P[o<<1|1].data(),lenp[o<<1],lenp[o<<1|1]);
}
void solve2(long long *s,long long *a,long long o,long long l,long long r){
if(r-l<=100){
for(long long i=l;i<=r;i++){
s[i]=0;
for(long long j=0,tmp=1;j<=lenq[o];j++,tmp=tmp*a[i]%mod) s[i]=(s[i]+tmp*Q[o][j]%mod)%mod;
}
return ;
}
long long mid=(l+r)>>1;
poly::Div(tmp,Q[o<<1].data(),Q[o].data(),P[o<<1].data(),lenq[o],lenp[o<<1]);
//左边取模
poly::Div(tmp,Q[o<<1|1].data(),Q[o].data(),P[o<<1|1].data(),lenq[o],lenp[o<<1|1]);
//右边取模
solve2(s,a,o<<1,l,mid);
solve2(s,a,o<<1|1,mid+1,r);
}
void solve(long long *s,long long *f,long long *a,long long n,long long m){
init(1,1,m);
solve1(a,1,1,m);
poly::Div(tmp,Q[1].data(),f,P[1].data(),n,lenp[1]);
solve2(s,a,1,1,m);
}
}
快速多项式多点插值
卡常数!
考虑多项式除法的常数实在是太大了,我们不想做那么多次多项式除法
考虑多项式除法的本质:
这里 \(Q(x)\) 是一次的,\(R(x)\) 是零次的,所以只要求出 \(Q(x)\) 的常数项即可得到答案
考虑分治:
试着推一下 \(G_{R0}^{-1}(x)\) 与 \(G_R^{-1}(x)\) 的关系:
于是:
所以算到底就相当于 \(F_R(x)\) 乘上若干个 \(G_{R}(x)\)
然而这没有用啊?时间复杂度不对
等一下,我们只需要 \(Q(x)\) 的常数项,也就是 \(Q_R(x)\) 的 \(x^{n-1}\) 项。
所以乘到某一个区间时,只有最高的 \(r-l+1\) 项是有用的了!(因为之后乘的东西是 \(r-l+1\) 次的)
这样时间复杂度就对了,\(O(n\log^2n)\)
同时只用做两次 NTT
了,非常的快
namespace Evaluation{
long long tmp[N];
vector <long long> P[N],Q[N],Qn[N];
long long lenp[N],lenq[N];
void init(long long o,long long l,long long r){
long long len=r-l+1,mid=(l+r)>>1ll;
lenp[o]=len;
lenq[o]=len-1;
P[o].resize(2*lenp[o]+2);
Q[o].resize(2*lenq[o]+2);
if(l==r) return ;
init(o<<1,l,mid);
init(o<<1|1,mid+1,r);
}
void solve1(long long *a,long long o,long long l,long long r){
if(l==r){
P[o]={1,(mod-a[l])%mod};
return ;
}
long long mid=(l+r)>>1;
solve1(a,o<<1,l,mid);
solve1(a,o<<1|1,mid+1,r);
NTT::solve(P[o].data(),P[o<<1].data(),P[o<<1|1].data(),lenp[o<<1],lenp[o<<1|1]);
}
void solve2(long long *s,long long *a,long long o,long long l,long long r){
if(l==r){
s[l]=Q[o][0];
return ;
}
long long mid=(l+r)>>1;
long long len1=mid-l+1,len2=r-mid;
NTT::solve(tmp,Q[o].data(),P[o<<1|1].data(),lenq[o],lenp[o<<1|1]);
for(long long i=0;i<len1;i++) Q[o<<1][i]=tmp[i+len2];
NTT::solve(tmp,Q[o].data(),P[o<<1].data(),lenq[o],lenp[o<<1]);
for(long long i=0;i<len2;i++) Q[o<<1|1][i]=tmp[i+len1];
solve2(s,a,o<<1,l,mid);
solve2(s,a,o<<1|1,mid+1,r);
}
void solve(long long *s,long long *f,long long *a,long long n,long long m){
m=max(n,m);
n=max(n,m);
//我也不知道为什么只有 n=m 时才对……
init(1,1,m);
solve1(a,1,1,m);
poly::Rev(f,f,n);
INV::solve(P[1].data(),P[1].data(),lenp[1]);
NTT::solve(tmp,f,P[1].data(),n,lenp[1]);
for(long long i=0;i<=m;i++) Q[1][i]=tmp[i];
solve2(s,a,1,1,m);
for(long long i=1;i<=m;i++)
s[i]=(f[n]+s[i]*a[i]%mod)%mod;
}
}
11. 多项式快速插值
-
给出 \(n+1\) 个点值 ,求满足这些点值的 \(n\) 次多项式 \(F(x)\)。
说到插值,那肯定是拉格朗日插值
考虑算这个系数,设 \(G(x)= \prod\limits_{i=1}^n (x-x_i)\)
那么这个系数就是 \(t=[\dfrac{G(x)}{x-x_j}](x_i)\)
听起来很有道理,但稍微一想就会发现分子分母都是零……
上洛必达法则:
\(t=[G'(x)](x_i)=G'(x_j)\)
同样考虑分治:
\(O(n\log^2n)\)
namespace Interpolation{
long long tmp[N];
vector <long long> P[N],Q[N];
long long lenp[N],lenq[N];
long long g[N],gv[N];
void init(long long o,long long l,long long r){
long long len=r-l+1,mid=(l+r)>>1ll;
lenp[o]=len;
lenq[o]=len-1;
P[o].resize(2*lenp[o]+2);
Q[o].resize(2*lenq[o]+2);
if(l==r) return ;
init(o<<1,l,mid);
init(o<<1|1,mid+1,r);
}
void solve1(long long *x,long long o,long long l,long long r){
if(l==r){
P[o]={(mod-x[l])%mod,1};
return ;
}
long long mid=(l+r)>>1;
solve1(x,o<<1,l,mid);
solve1(x,o<<1|1,mid+1,r);
NTT::solve(P[o].data(),P[o<<1].data(),P[o<<1|1].data(),lenp[o<<1],lenp[o<<1|1]);
}
void solve2(long long *x,long long *y,long long o,long long l,long long r){
if(l==r){
Q[o]={y[l]*ksm(gv[l],mod-2)%mod};
return ;
}
long long mid=(l+r)>>1;
solve2(x,y,o<<1,l,mid);
solve2(x,y,o<<1|1,mid+1,r);
long long tmp[N];
NTT::solve(Q[o].data(),Q[o<<1].data(),P[o<<1|1].data(),lenq[o<<1],lenp[o<<1|1]);
NTT::solve(tmp,Q[o<<1|1].data(),P[o<<1].data(),lenq[o<<1|1],lenp[o<<1]);
for(long long i=0;i<=lenq[o];i++) Q[o][i]=(Q[o][i]+tmp[i])%mod;
}
void solve(long long *s,long long *x,long long *y,long long n){
init(1,1,n);
solve1(x,1,1,n);
poly::Deriv(g,P[1].data(),lenp[1]);
Evaluation::solve(gv,g,x,lenp[1]-1,n);
solve2(x,y,1,1,n);
for(long long i=0;i<=n;i++) s[i]=Q[1][i];
}
}
12. 下降幂多项式乘法
-
给出下降幂多项式 \(F(x),G(x)\) ,求下降幂多项式 \(H(x)\)。
与 NTT
类似,考虑多项式与点值的变换
不过由于下降幂优美的性质,点值只要选择 0-n 的 \(n+1\) 个数就好了
在 Re:从零开始的生成函数魔法 中的 3.4 有,这里再写一遍
设 \(a_i=F(i)\),\(A(x)\) 是 \(a_n\) 的生成函数
FTT可还行
void ftt(long long *s,long long *f,long long n,long long typ){
static long long A[N];
if(typ==1){
A[1]=1;
NTT::solve(s,inv,f,n,n);
for(long long i=0;i<=n;i++) s[i]=s[i]*mul[i]%mod;
}
else{
for(long long i=0;i<=n;i++) f[i]=f[i]*inv[i]%mod;
for(long long i=0;i<=n;i++) A[i]=inv[i]*ID(i)%mod;
NTT::solve(s,A,f,n,n);
}
}
void FTT(long long *s,long long *f,long long *g,long long n,long long m){
static long long A[N],B[N];
ftt(A,f,n+m,1);
ftt(B,g,n+m,1);
for(long long i=0;i<=n+m;i++) s[i]=A[i]*B[i]%mod;
ftt(s,s,n+m,-1);
}
13. 普通多项式与下降幂多项式互化
-
给出下降幂/普通多项式 \(F(x)\) ,求它的普通/下降幂多项式形式 \(G(x)\)。
普通转下降幂:多点求值再 fft
插值
下降幂转普通: fft
求值再快速插值
void FtoP(long long *s,long long *f,long long n){
static long long A[N],B[N];
ftt(A,f,n,1);
for(long long i=n+1;i>=1;i--) A[i]=A[i-1];
for(long long i=1;i<=n+1;i++) B[i]=i-1;
Interpolation::solve(s,B,A,n+1);
}
void PtoF(long long *s,long long *f,long long n){
static long long A[N],B[N];
for(long long i=1;i<=n+1;i++) B[i]=i-1;
Evaluation::solve(A,f,B,n,n+1);
for(int i=0;i<=n;i++) A[i]=A[i+1];
ftt(s,A,n,-1);
}
14. 任意模数 NTT
多项式卷积,但是任意模数
事实上就用三个模数 998244353 / 1004535809 / 469762049 (原根都是 3)分别做一遍再 CRT
合并起来就好了
由于比较特殊不列在总板子里
namespace EXCRT{
long long solve(long long x1,long long x2,long long x3,long long mod){
long long k1=(x2+mod2-x1)%mod2*ksm(mod1,mod2-2,mod2)%mod2;
long long x4=x1+k1*mod1;
long long k4=(x3+mod3-x4%mod3)%mod3*ksm(mod1*mod2%mod3,mod3-2,mod3)%mod3;
return (x4%mod+k4%mod*mod1%mod*mod2%mod)%mod;
}
}
long long st1[N],st2[N],st3[N];
void MTT(long long *s,long long *a,long long *b,long long n,long long m,long long p){
NTT::solve(st1,a,b,n,m,mod1);
NTT::solve(st2,a,b,n,m,mod2);
NTT::solve(st3,a,b,n,m,mod3);
for(long long i=0;i<=n+m;i++){
s[i]=EXCRT::solve(st1[i],st2[i],st3[i],p);
}
}
15. 分治 FFT(半在线卷积)
假的半在线卷积
- 已知 \(F(x)\),求 \(G(x)\) 使 \(g_n=\sum\limits_{i=1}^n f_ig_{n-i}\),边界条件为 \(g_0=1\)。
知道 \(G(x)\) 的前 \(n\) 项,可以推出第 \(n+1\) 项。
暴力做是 \(O(n^2 \log n)\) 的,但是做了很多无用功,考虑用一些方法优化。
可以类比 cdq 分治:现在的区间 \(G\) 分治成左右两个区间(记作 \(G_l\) 与 \(G_r\)),考虑左边对右边的贡献。
再记 \(F_l'\) 与 \(F'\) 为 \(F(x)\) 的前缀多项式(不太严谨),且 \(\deg F_l'=\deg G_l\),\(\deg F'=\deg G\),即加上 \('\) 就代表被平移到了开头。
显然贡献只有 \(F' * G_l\)。单次 NTT 的时间复杂度为 \(O(n \log n)\),套上分治的总复杂度即为 \(O(n \log^2 n)\)。
真的半在线卷积
- 求 \(F(x)\),\(G(x)\) 使 \(f_n=g_1 \oplus g_2 \oplus \cdots \oplus g_n\),\(g_n=\sum\limits_{i=1}^n f_ig_{n-i}\),边界条件为 \(f_0=g_0=0\),\(f_1=1\)。
现在必须知道 \(f_1,\cdots,f_n\),才能算出 \(f_{n+1}\),没办法再像刚才那样计算。
具体而言,如果左端点不为 \(1\),那 \(F'\) 已经被算出来了,没有问题。
但若左端点为 \(1\),区间 \(G\) 没算出来,\(F'=F\) 自然也不知道。
那现在只考虑左端点为 \(1\) 的情况。这时候 \(F_r'\) 不知道,只能将 \(F_l' * G_l\) 贡献到 \(G_r\) 去。
因此 \(G_r\) 还缺了 \(F_r' * G_l\),没关系,以后再说。
幸运的是,这时候 \(G_r\) 中的第一项系数是正确的,\(F_r\) 的第一项和第二项也计算得出来。
因此现在考虑分治到 \(G_r\) 里面,再分治到区间 \(GG\),可以将 \(FF’ * GG_l\) 贡献到 \(GG_r\)。
但是别忘了,\(G_r\) 还缺了 \(F_r' * G_l\),因此还要再补上 \(FF * GG_l’\)。
每分治到叶子节点就计算一些 \(f_n\),这样就正确了。
时间复杂度仍是 \(O(n \log^2 n)\)。
16. 分治 + NTT
(接下来介绍一些技巧)
- 给定 \(n\) 个一次多项式 \(a_i+b_ix\),求 \(\prod\limits_{i=1}^n (a_i+b_ix)\)。
暴力卷积,是 \(>O(n^2)\) 的……
原因就在于卷着卷着次数会不断变大
想象一下,我们平时做多个数的乘法,没有人会去顺序一个一个地乘吧……
那可以用类似分治的办法,对于每个区间将左右合并起来
时间复杂度是 \(T(n)=2T(n/2)+O(n\log n)=O(n\log^2n)\)
具体的实现可以用 vector
存多项式,空间复杂度是 \(O(n\log n)\) 的
类似的,还可以拓展出分式求和:
- 给定 \(n\) 个一次多项式 \(a_i+b_ix\),求 \(\sum\limits_{i=1}^n \dfrac{1}{a_i+b_ix}\)。
如果直接求逆是 \(O(n^2\log n)\) 的
同上,对于每个区间维护分子和分母,最后再求逆算答案
时间复杂度 \(O(n\log^2n)\)
空间复杂度 \(O(n\log n)\)
总结
在我的认知中多项式主要是为生成函数服务的
- 广告!生成函数与解析组合
还有一些多点求值插值的模板……以后再补吧
学习这些东西其实不需要太多背诵,都是可以推出来的
不过多写写也会熟练点
最后,一整套板子!
#include<bits/stdc++.h>
using namespace std;
const long long N=6e5+5,mod=998244353,inv2=499122177,g=3;
long long mul[N],inv[N],invv[N],pw[N];
long long ksm(long long f,long long x){
long long tot=1;
while(x){
if(x & 1ll) tot=tot*f%mod;
f=f*f%mod;
x>>=1ll;
}
return tot;
}
map <long long,long long> mp;
long long BSGS(long long a,long long p,long long n){
mp.clear();
long long m=ceil(sqrt(p)),tmp=1;
for(long long i=0;i<m;i++) mp[tmp*n%p]=i+1,tmp=tmp*a%p;
long long tmpp=1;
for(long long i=0;i<=m;i++){
if(mp[tmpp]) return i*m-mp[tmpp]+1;
tmpp=tmpp*tmp%p;
}
return -1;
}
long long ID(long long x){
if(x & 1ll) return mod-1;
else return 1;
}
long long S(long long x){
return x*(x-1)/2;
}
long long pr[N];
void pre(){
long long lim=N-5;
mul[0]=inv[0]=1;
for(long long i=1;i<=lim;i++) mul[i]=mul[i-1]*i%mod;
inv[lim]=ksm(mul[lim],mod-2);
for(long long i=lim-1;i>=1;i--) inv[i]=inv[i+1]*(i+1)%mod;
invv[1]=1;
for(int i=2;i<=lim;i++) invv[i]=(mod-mod/i)*invv[mod%i]%mod;
pw[0]=1;
for(long long i=1;i<=lim;i++) pw[i]=pw[i-1]*2%mod;
for(long long mid = 1;mid < N/2;mid <<= 1){
long long Wn = ksm(g, (mod-1)/(mid<<1)); pr[mid] = 1;
for(long long i = 1;i < mid;++ i) pr[mid+i] = 1ll*pr[mid+i-1] * Wn % mod;
}
}
namespace NTT{
long long qmod(long long x){
if(x>=mod) x-=mod;
return x;
}
void qmo(long long &x){x += x >> 31 & mod;}
long long A[N],B[N],rev[N];
long long init(long long n,long long m){
long long lim=0;
while((1ll<<lim)<=n+m) lim++;
for(long long i=0;i<=(1<<lim)-1;i++)
rev[i]=(rev[i>>1ll]>>1ll) | ((i & 1ll)<<(lim-1));
return lim;
}
void ntt(long long *A,long long lim, long long op){
for(long long i = 0;i < lim;++ i)
if(i < rev[i]) swap(A[i], A[rev[i]]);
for(long long mid = 1;mid < lim;mid <<= 1)
for(long long i = 0;i < lim;i += mid<<1)
for(long long j = 0;j < mid;++ j){
long long y = 1ll*A[mid+i+j] * pr[mid+j] % mod;
qmo(A[mid+i+j] = A[i+j] - y);
qmo(A[i+j] += y - mod);
}
if(op==-1){ reverse(A+1, A+lim); long long inv = ksm(lim, mod-2);
for(long long i = 0;i < lim;++ i) A[i] = 1ll*A[i] * inv % mod;}
}
void solve(long long *s,long long* f,long long* g,long long n,long long m){
if(n+m<=100){
long long A[N];
for(long long i=0;i<=n+m;i++) A[i]=0;
for(long long i=0;i<=n;i++){
for(long long j=0;j<=m;j++){
A[i+j]=qmod(A[i+j]+1ll*f[i]*g[j]%mod);
}
}
for(long long i=0;i<=n+m;i++) s[i]=A[i];
return ;
}
m--;
long long lim=init(n,m);
for(long long i=0;i<=n;i++) A[i]=f[i];
for(long long i=0;i<=m;i++) B[i]=g[i];
for(long long i=n+1;i<=(1ll<<lim);i++) A[i]=0;
for(long long i=m+1;i<=(1ll<<lim);i++) B[i]=0;
ntt(A,(1<<lim),1);
ntt(B,(1<<lim),1);
for(long long i=0;i<=(1<<lim)-1;i++) s[i]=1ll*A[i]*B[i]%mod;
ntt(s,(1<<lim),-1);
m++;
s[n+m]=0;
for(long long i=m;i<=n+m;i++) s[i]=(s[i]+1ll*f[i-m]*g[m]%mod)%mod;
}
}
namespace INV{
long long A[N],B[N],S[N];
void solve(long long *s,long long *f,long long n){
S[0]=ksm(f[0],mod-2);
S[1]=0;
long long len;
for(len=2;len<=(n<<1ll);len<<=1ll){
long long lim=len<<1ll;
for(long long i=0;i<len;i++) A[i]=f[i],B[i]=S[i];
for(long long i=len;i<lim;i++) A[i]=B[i]=0;
NTT::init(len,0);
NTT::ntt(A,lim,1);
NTT::ntt(B,lim,1);
for(long long j=0;j<lim;j++)
S[j]=(2*B[j]+mod-A[j]*B[j]%mod*B[j]%mod)%mod;
NTT::ntt(S,lim,-1);
for(long long j=len;j<lim;j++) S[j]=0;
}
for(long long i=0;i<=n;i++) s[i]=S[i];
}
}
namespace poly{
void Add(long long *s,long long *f,long long *g,long long n,long long m){
for(long long i=0;i<=max(n,m);i++) s[i]=(f[i]+g[i])%mod;
}
//加
void Del(long long *s,long long *f,long long *g,long long n,long long m){
for(long long i=0;i<=max(n,m);i++) s[i]=(f[i]+mod-g[i])%mod;
}
//减
void Mult(long long *s,long long *f,long long n,long long k){
for(long long i=0;i<=n;i++) s[i]=(f[i]*k%mod+mod)%mod;
}
//乘
void Rev(long long *s,long long *f,long long n){
for(long long i=0;i<=n/2;i++) swap(f[i],f[n-i]);
}
//翻转
void Deriv(long long *s,long long *f,long long n){
long long A[N];
A[n]=0;
for(long long i=1;i<=n;i++) A[i-1]=f[i]*i%mod;
for(long long i=0;i<=n;i++) s[i]=A[i];
}
//求导
void Limit(long long *s,long long *f,long long n){
long long A[N];
A[0]=0;
for(long long i=0;i<=n-1;i++) A[i+1]=f[i]*ksm(i+1,mod-2)%mod;
for(long long i=0;i<=n;i++) s[i]=A[i];
}
//极限
void Div(long long *q,long long* r,long long *f,long long *g,long long n,long long m){
long long A[N],B[N],C[N],D[N];
for(long long i=0;i<=n;i++) A[i]=C[i]=f[i];
for(long long i=0;i<=m;i++) B[i]=D[i]=g[i];
Rev(A,A,n);
Rev(B,B,m);
INV::solve(B,B,n-m);
NTT::solve(q,A,B,n,n-m);
for(long long i=n-m+1;i<=n+n-m;i++) q[i]=0;
Rev(q,q,n-m);
NTT::solve(D,D,q,m,n-m);
Del(r,C,D,n,n);
}
//带余除法
void Devide_solve(long long *s,long long *g,long long l,long long r){
if(l==r) return ;
long long mid=(l+r)>>1ll;
Devide_solve(s,g,l,mid);
long long lim=r-l+1,len=lim>>1ll;
static long long A[N],B[N];
for(long long i=1;i<=len;i++) A[i]=s[l+i-1];
for(long long i=1;i<=lim;i++) B[i]=g[i];
NTT::solve(A,A,B,len,lim);
for(long long i=len+1;i<=lim;i++) s[i-len+mid]=(s[i-len+mid]+A[i])%mod;
Devide_solve(s,g,mid+1,r);
}
void Devide(long long *s,long long *g,long long n){
long long lim=NTT::init(n,0);
long long m=1ll<<lim;
for(long long i=0;i<=m;i++) s[i]=0;
s[1]=1;
Devide_solve(s,g,1,m);
}
//分治NTT
void Ln(long long *s,long long *f,long long n){
long long A[N],B[N];
long long lim=NTT::init(n,n);
for(long long i=0;i<=n;i++) A[i]=f[i],B[i]=0;
for(long long i=n+1;i<(1ll<<lim);i++) A[i]=B[i]=0;
Deriv(B,A,n);
INV::solve(A,A,n);
NTT::solve(A,A,B,n,n);
Limit(A,A,n);
for(long long i=0;i<=n;i++) s[i]=A[i];
}
//对数
vector <long long> P[N];
void solve(long long n,long long o,long long l,long long r){
long long len=r-l+1;
P[o].resize(4*len+2);
if(l==r){
if(l<=n) P[o]={l-1,1};//在这里改
else P[o]={1};
return ;
}
long long mid=(l+r)>>1;
solve(n,o<<1,l,mid);
solve(n,o<<1|1,mid+1,r);
NTT::solve(P[o].data(),P[o<<1].data(),P[o<<1|1].data(),len>>1,len>>1);
}
//分治+NTT
void Sqrt(long long *s,long long *f,long long n){
long long A[N],B[N],C[N],S[N];
long long p=BSGS(3,998244353,f[0]);
S[0]=ksm(3,p/2);
S[1]=0;
long long len;
for(len=1;len<=(n<<1ll);len<<=1ll){
long long lim=len<<1ll;
for(long long i=0;i<len;i++) A[i]=f[i],B[i]=S[i];
for(long long i=len;i<lim;i++) A[i]=B[i]=C[i]=0;
INV::solve(C,B,len);
NTT::init(len,0);
NTT::ntt(A,lim,1);
NTT::ntt(C,lim,1);
for(long long j=0;j<lim;j++) S[j]=A[j]*C[j]%mod;
NTT::ntt(S,lim,-1);
for(long long i=0;i<lim;i++) S[i]=(S[i]+B[i])%mod*inv2%mod;
for(long long j=len;j<lim;j++) S[j]=0;
}
for(long long i=0;i<=n;i++) s[i]=S[i];
}
//开根
void Exp(long long *s,long long *f,long long n){
long long A[N],B[N],C[N],S[N];
S[0]=1;
S[1]=0;
long long len;
for(len=2;len<=(n<<1ll);len<<=1ll){
long long lim=len<<1ll;
for(long long i=0;i<len;i++) A[i]=f[i],B[i]=S[i];
for(long long i=len;i<lim;i++) A[i]=B[i]=C[i]=0;
Ln(C,B,len-1);
for(long long i=0;i<len;i++) C[i]=(mod-C[i]+A[i])%mod;
C[0]=(C[0]+1)%mod;
NTT::init(len,0);
NTT::ntt(B,lim,1);
NTT::ntt(C,lim,1);
for(long long j=0;j<lim;j++) S[j]=B[j]*C[j]%mod;
NTT::ntt(S,lim,-1);
for(long long j=len;j<lim;j++) S[j]=0;
}
for(long long i=0;i<=n;i++) s[i]=S[i];
}
//指数
void Pow(long long *s,long long *f,long long n,long long k1,long long k2,long long lenk){
long long A[N];
long long lim=NTT::init(n,n);
long long t=0;
while(!f[t]) t++;
if(lenk>5 && t){
for(long long i=0;i<=n;i++) s[i]=0;
return ;
}
long long p=ksm(f[t],k2),q=ksm(f[t],mod-2);
n-=t;
for(long long i=0;i<=n;i++) A[i]=f[i+t]*q%mod;
for(long long i=n+1;i<(1ll<<lim);i++) A[i]=0;
Ln(A,A,n);
Mult(A,A,n,k1);
Exp(A,A,n);
for(long long i=0;i<=n+t;i++) s[i]=0;
for(long long i=0;i+t*k1<=n+t;i++)
s[i+t*k1]=A[i]*p%mod;
}
//快速幂
}
参考资料: