转置原理与多项式多点求值
终于学转置原理了,之前一直听 zhy 糊多项式题不知道他在讲写啥。
自己的多项式水平长期停留在多项式除法,直到今天做互测时被迫学了怎么去多点求值。正式比赛大概率不考(吧?)所以学来娱乐一下。
普通多点求值算法
思想很妙,效率很逊。代码不写了因为我连多项式取模都忘了怎么写了。
考虑类似 CRT 和拉插的思想,对于点值集合 \(P\) 构造多项式 \(M_P(x)=\prod_{c\in P}(x-c)\)。这样就有了 \(\forall x\in P,M_P(x)=0\)。
然后考虑分治计算贡献,每次将点值均匀分成两个集合 \(P_0,P_1\),对于当前多项式 \(F\),多项式取模一下将其表示为 \(F=Q_0 M_{P_0}+F_0\) 和 \(F=Q_1 M_{P_1}+F_1\)。
那么你代入 \(P_0\) 中的点值时商的贡献就被消掉了,有 \(F(x)=F_0(x)\),这样 \(P,F\) 的规模双双减半。于是你往下递归做就可以了。\(M_P\) 可以在分治过程中递归维护出来。
每一层都带多项式取模(求逆)的 \(\log\),总复杂度 \(O(n\log^2 n)\),而且常数上天。
快速插值算法
多点求值本质上让我们乘上任意一个范德蒙德矩阵,而多项式操作基本上都是对于程序中变量的线性变换。这启发我们插值作为求值的逆算法也可以做到同等复杂度。
快速插值我们考虑对着拉插公式大力优化:\(F=\sum y_i \prod_{i\neq j} \frac{x-x_j}{x_i-x_j}\)。
考虑那个分母,这个东西里面似乎就是一个 \(M_{P/\{x_i\}}(x_i)\),也就是 \(M_P(x_i)\) 除以一个 \(x_i-x_i\),发现关键在零除零,那你发动技能“洛神”对里面大力洛必达一下,就可以得到 \(M_{P/\{x_i\}}(x_i)=M_P'(x_i)\)。也就是说你直接分治 NTT 出 \(M_P\) 之后求个导,然后再套多点求值就可以了!
剩下的相当于要计算 \(\sum \frac{y_i}{M_{P/\{x_i\}}(x_i)} \prod_{i\neq j} (x-x_j)\),直接上经典的缺一分治,也是分治 NTT,做完了!复杂度同多点求值,常数更上一层楼!
接下来是发病时间:
啊你看这个转置原理啊,就是说多点求值作为一个线性算法时可以分解成初等行矩阵然后去转置它(\((AB)^T=B^TA^T\))。那你也可以分解成初等行矩阵之后求逆它啊(\((AB)^{-1}=B^{-1}A^{-1}\))。
那你可以把多点求值中的所有变量都当成向量中的变量,把算法全部描述成乘上初等行变换矩阵,然后 reverse
一下这个操作序列,再一个一个还原回去就行了。
不要提常数的事,这个做法时空都是 \(O(n\log^2 n)\) 的毫无实际意义。
转置原理
看的王总博客
上面的技术在 EI 一行人引入转置原理之后就显得过时了整个多项式内容都过时了。
多项式操作经常是线性变换,你仔细读一下 DFT 的代码就会发现你唯一做了的事要么就是将一个数的若干倍加到另一个数上,或者说交换两个数。
对于这种“线性算法”,它总能够被描述成乘上一个矩阵的样子 \(\vec{u}=A\vec{v}\),而且我们知道了如果它是可以快速做的,那么在同等的时间复杂度内它的逆变换时可以快速做的,具体看上文“发病时间“。同理它的转置算法,也就是求 \(A^T \vec{v}\) 也是可以快速做的。
卷积算法是可以转置的,这可以看王总博客,讲得很详细。它转置之后是一个减法卷积。注意一下转置要求乘法必须有一方是常量,对多项式也是这样要求的。
设 \(f_i=\sum_{j=0}^i g_jh_{i-j}\) 中 \(h\) 数组是常量,那么对于卷积转置实际上算的就是:
(其实从王总博客里还知道一个减法卷积卡常小技巧,反正你前若干位是不要的所以可以把 NTT 长度开小一点让它去溢出。)
那 DFT 呢?搞笑了,它的转置就是它自己。
直接讲讲多点求值吧。多点求值就是直接乘上任意一种范德蒙德矩阵,相当于:
你把它转置,猜猜它是啥:
是你,带权等幂和问题!我们考虑先如何解决这个问题,写出 GF:
这里有个求和,看起来比较棘手。我们不妨考虑通分,用分治 NTT 去处理:
用分治 NTT 直接维护出 \(A,B\) 就可以做了。
把这个算法转置一下,成啥样了呢?对于这种大型算法的转置,我们可以把它分解成若干个线性算法的步骤,让它们的输入输出首尾相接,然后依次转置。比如这里把这个算法分解成了若干次多项式加法和乘法。
有人问,求逆咋转置呢?这不是线性算法啊?但是你求逆的东西是分母,分母只与 \(x_i\) 有关,而 \(x_i\) 你是写在矩阵里而不是向量里的,也就是说只跟 \(x_i\) 有关的东西在转置意义下其实是常量。
注意一下所谓转置一定不要把常量运算给转置了。这个“常量”不一定是一般讲的一个数,跟输入输出无关的多项式也要不转。
那么转置之后就变成了一个反着的分治,每次把多项式往底层递归而不是往上合并计算就可以了。注意这个过程中的多项式乘法是转置多项式乘法。
代码正在努力实现中。
多项式求逆
这个 sb 在实现代码的时候突然发现自己连多项式怎么求逆都忘了。
这下记住了吧??!!
代码
终于写完了。
upd: 又改动了一点细节,还写了快速插值的板子。
下面这个代码是快速插值,当然也包含了多点求值。
#include <cstdio>
#include <vector>
#include <cassert>
#include <algorithm>
#define lc (p<<1)
#define rc (p<<1|1)
#define ALL(p) p.begin(),p.end()
#define IL inline
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<22,stdin)),p1==p2?EOF:*p1++)
using namespace std;
char buf[1<<22],*p1=buf,*p2=buf;
int read(){
char c=getchar();int x=0;
while(c<48||c>57) c=getchar();
do x=(x<<1)+(x<<3)+(c^48),c=getchar();
while(c>=48&&c<=57);
return x;
}
typedef vector<int> vi;
const int P=998244353,N=100003;
typedef long long ll;
IL int qp(int a,int b=P-2){
int res=1;
while(b){
if(b&1) res=(ll)res*a%P;
a=(ll)a*a%P;b>>=1;
}
return res;
}
int len,ilen,bt;
int rev[1<<20],cw[1<<20|1];
IL void init(int _len){ // mod x^len
len=1;bt=-1;
while(len<_len) len<<=1,++bt;
int w=qp(3,(P-1)>>(bt+1));
cw[0]=cw[len]=1;
for(int i=1;i<len;++i){
cw[i]=(ll)cw[i-1]*w%P;
rev[i]=(rev[i>>1]>>1)|((i&1)<<bt);
}
ilen=qp(len);
}
struct poly{
vi f;
poly():f(){}
poly(int Len):f(Len){}
poly(initializer_list<int> Init):f(Init){}
IL void NTT(){
f.resize(len,0);
for(int i=1;i<len;++i) if(rev[i]<i) swap(f[rev[i]],f[i]);
for(int i=1,tt=len>>1;i<len;i<<=1,tt>>=1)
for(int j=0;j<len;j+=(i<<1))
for(int k=j,t=0;k<(j|i);++k,t+=tt){
int x=f[k],y=(ll)f[k|i]*cw[t]%P;
if((f[k]=x+y)>=P) f[k]-=P;
if((f[k|i]=x-y)<0) f[k|i]+=P;
}
}
IL void INTT(){
for(int i=1;i<len;++i) if(rev[i]<i) swap(f[rev[i]],f[i]);
for(int i=1,tt=len>>1;i<len;i<<=1,tt>>=1)
for(int j=0;j<len;j+=(i<<1))
for(int k=j,t=len;k<(j|i);++k,t-=tt){
int x=f[k],y=(ll)f[k|i]*cw[t]%P;
if((f[k]=x+y)>=P) f[k]-=P;
if((f[k|i]=x-y)<0) f[k|i]+=P;
}
for(int i=0;i<len;++i) f[i]=(ll)f[i]*ilen%P;
}
IL void reduce(){while(!f.empty()&&!f.back()) f.pop_back();}
IL void trunc(int lim){
if(lim<int(f.size())) f.erase(f.begin()+lim,f.end());
}
IL poly inv(int lim){ // mod x^lim
assert(f[0]);
poly cur({qp(f[0])});
for(int t=1;t<lim;t<<=1){
poly ff(t<<2);
copy(f.begin(),f.begin()+min(t<<1,int(f.size())),ff.f.begin());
init(t<<2);ff.NTT();cur.NTT();
poly tmp(len);
for(int i=0;i<len;++i){
tmp.f[i]=(2ll-(ll)cur.f[i]*ff.f[i])%P*cur.f[i]%P;
if(tmp.f[i]<0) tmp.f[i]+=P;
}
tmp.INTT();
cur.f.swap(tmp.f);
cur.trunc(t<<1);
}
cur.trunc(lim);
return cur;
}
IL void plus(poly A,poly B){
int mx=max(A.f.size(),B.f.size());
A.f.resize(mx,0);B.f.resize(mx,0);
f.resize(mx);
for(int i=0;i<mx;++i){
f[i]=A.f[i]+B.f[i];
if(f[i]>=P) f[i]-=P;
}
}
IL void prod(poly A,poly B){
init(A.f.size()+B.f.size()-1);A.NTT();B.NTT();
f.resize(len);
for(int i=0;i<len;++i) f[i]=(ll)A.f[i]*B.f[i]%P;
INTT();reduce();
}
IL void prodT(poly A,poly B){
int an=A.f.size()-1,bn=B.f.size()-1;
reverse(ALL(B.f));prod(A,B);
for(int i=0;i<=an;++i) f[i]=f[i+bn];
trunc(an+1);
}
IL int calc(int t){
int pw=1,res=0;
for(int x:f){
res=(res+(ll)pw*x)%P;
if(res>=P) res-=P;
pw=(ll)pw*t%P;
}
return res;
}
IL poly deriv(){
int n=f.size();
poly D(n-1);
for(int i=1;i<n;++i) D.f[i-1]=(ll)f[i]*i%P;
return D;
}
}F;
int n;
int px[N],py[N],res[N];
namespace multieval{
poly G[N<<2],H[N<<2];
void calc(int p,int l,int r){
if(l==r){G[p]=poly({1,(P-px[r])%P});return;}
int mid=(l+r)>>1;
calc(lc,l,mid);
calc(rc,mid+1,r);
G[p].prod(G[lc],G[rc]);
}
void solve(int p,int l,int r){
H[p].trunc(r-l+1);
if(l==r){res[r]=H[p].f[0];return;}
int mid=(l+r)>>1;
G[p].f.clear();G[p].f.shrink_to_fit();
H[lc].prodT(H[p],G[rc]);
H[rc].prodT(H[p],G[lc]);
H[p].f.clear();H[p].f.shrink_to_fit();
solve(lc,l,mid);
solve(rc,mid+1,r);
}
void sol(){
calc(1,1,n);
H[1].prodT(F,G[1].inv(n+1));
solve(1,1,n);
}
}
poly T[N<<2];
void eval(int p,int l,int r){
if(l==r){T[p]=poly({(P-px[r])%P,1});return;}
int mid=(l+r)>>1;
eval(lc,l,mid);
eval(rc,mid+1,r);
T[p].prod(T[lc],T[rc]);
}
poly proc(int p,int l,int r){
if(l==r) return poly({py[r]});
int mid=(l+r)>>1;
poly L=proc(lc,l,mid);
poly R=proc(rc,mid+1,r);
L.prod(L,T[rc]);
R.prod(R,T[lc]);
poly M;M.plus(L,R);
return M;
}
int main(){
n=read();
for(int i=1;i<=n;++i) px[i]=read(),py[i]=read();
eval(1,1,n);
F=T[1].deriv();
multieval::sol();
for(int i=1;i<=n;++i) py[i]=(ll)py[i]*qp(res[i])%P;
poly R=proc(1,1,n);
R.f.resize(n-1,0);
for(int i=0;i<n;++i) printf("%d ",R.f[i]);
putchar('\n');
return 0;
}