转置原理
基本原理
基础公式
\(V=E_1E_2\cdots E_n\to V^T=E_n^TE_{n-1}^T\cdots E_1^T\)
\(E\)是基础矩阵,分为
- 让某一位\(i\)乘上\(k\),转置后效果一样
- 让某一位\(i\)乘上\(k\)加到另一位\(j\)上,转置后是把\(j\)乘上\(k\)加到\(i\)上
转置fft
\(A\times B=C\),即\(C_{i+j}+=A_i\cdot B_j\)
如果把\(A\)当作常数,\(B\)当作输入的\(f\),就是把\(j\)这一位乘上\(A_i\)再加到\(i+j\)这位上
那么转置后就是把\(i+j\)这一位乘上\(A_i\)加到\(j\)这一位上
即\(B_j+=A_i\cdot C_{i+j}\),或者\(B_{i-j}+=C_i\cdot A_j\)
也就是说加法卷积转置一下就变成了一个减法卷积,记为\(C\times^TA=B\)
具体实现为翻转\(A\)(\(A_0=A_{len}\))做一次多项式乘法,然后\(B_i=B_{i+len_A}\)
不难发现(长度为\(n+m\)的多项式\()\times^T\)(长度为\(n\)的多项式)\(=\)(长度为\(m\)的多项式),而且没有交换律
具体作用
题目要求\(Vf\),直接求不好求,但是很容易想到如何求\(V^Tf\)。所以考虑求出\(f\to V^Tf\)这个过程在干什么,再把这个过程转置一下(就是把过程完全倒过来)就能得到\(f\to Vf\)的做法
实践
多项式多点求值
可以看作是\(\begin{bmatrix}1&a_0&a_0^2&\cdots\\1&a_1&a_1^2&\cdots\\1&a_2&a_2^2&\cdots\\\cdots\\1&a_{n-1}&a_{n-1}^2&\cdots\end{bmatrix}\times f\)(直接展开就是所求的式子)
设左边的为\(V\),考虑如何求\(V^Tf\)
即为\([x^k]\sum_{i=0}^{n-1}f_ia_i^k\),也就是求多项式\(\sum_{i=0}^{n-1}\frac{f_i}{1-a_ix}\)
这个式子用分治fft求出,具体过程是:
- 当前在\((l,r)\)
- 递归处理\((l,mid)\) \((mid+1,r)\)
- 设左边为\(\frac{B_l}{A_l}\),右边为\(\frac{B_r}{A_r}\)
- 那么当前区间分子为\(B_l\times A_r+B_r\times A_l\),分母为\({A_l\times A_r}\)
最后答案为\(ans=B_{[0,n-1]}\times A_{[0,n-1]}^{-1}\)
那么如何求\(f\to Vf\)呢,就是把过程完全反过来(注意\(A\times B=C\)倒过来就变成了\(C\times^T B=A\),假设\(A\)为变量,\(B\)为常数)
由于\(B\)是一个常数,或者说是一个与\(f\)无关的式子,所以反过来没有变化
先求出\(B_{[0,n-1]}=f\times^T A_{[0,n-1]}^{-1}\)(这里可以理解为往\(f\)后面添\(0\)使\(f\)长度为\(2n\))
\(A_{[0,n-1]}\)可以用一次分治fft求出
然后用分治fft:
- 当前在\((l,r)\),当前分子为\(B\)
- 求出左边的分子为\(B\times^T A_r\),右边的分子为\(B\times^T A_l\)
- 递归处理\((l,mid)\) \((mid+1,r)\)
最后递归到叶节点\((i,i)\)时,\(ans_i=[x^0]B_{[i,i]}\)
仔细观察的话会发现,就是完全把之前的过程反了过来(包括顺序)
然后这种方法只用到了一次多项式求逆(最开始),和分治fft,显然常数比一般要用取膜的做法小,优化得好甚至可以5e5多点求值(膜$E)。
#include <bits/stdc++.h>
using namespace std;
template<typename T>
inline void Read(T &n){
char ch; bool flag=false;
while(!isdigit(ch=getchar()))if(ch=='-')flag=true;
for(n=ch^48;isdigit(ch=getchar());n=(n<<1)+(n<<3)+(ch^48));
if(flag)n=-n;
}
const int MAXN = 64005;
const int MOD = 998244353;
const int G = 3;
inline int inc(int a, int b){
a += b;
if(a>=MOD) a -= MOD;
return a;
}
inline void iinc(int &a, int b){a = inc(a,b);}
inline int dec(int a, int b){
a -= b;
if(a<0) a += MOD;
return a;
}
inline void ddec(int &a, int b){a = dec(a,b);}
inline int ksm(int base, int k=MOD-2){
int res = 1;
while(k){
if(k&1)
res = 1ll*res*base%MOD;
base = 1ll*base*base%MOD;
k >>= 1;
}
return res;
}
typedef vector<int> poly;
int tr[MAXN<<2], wn[MAXN<<2];
inline int prework(int n){
int len = 1; while(len<=n) len<<=1;
for(register int i=0; i<=len; i++) tr[i] = (tr[i>>1]>>1)|((i&1)?len>>1:0);
wn[0] = 1; wn[1] = ksm(G,(MOD-1)/len); for(register int i=2; i<=len; i++) wn[i] = 1ll*wn[i-1]*wn[1]%MOD;
return len;
}
poly f, g;
inline void ntt(poly &f, int n, int flag){
for(register int i=0; i<n; i++) if(i<tr[i]) swap(f[i],f[tr[i]]);
for(register int len=2; len<=n; len<<=1){
int base = flag*n/len;
for(register int l=0; l<n; l+=len){
int now = flag==1?0:n;
for(register int i=l; i<l+len/2; i++){
int tmp = 1ll*f[i+len/2]*wn[now]%MOD;
f[i+len/2] = dec(f[i],tmp);
f[i] = inc(f[i],tmp);
now += base;
}
}
}
if(flag==-1){
int invn = ksm(n);
for(register int i=0; i<n; i++) f[i] = 1ll*f[i]*invn%MOD;
}
}
void poly_inv(poly f, poly &res, int n){
static poly tmp;
if(n==1) return res.resize(1), res[0] = ksm(f[0]), void();
poly_inv(f,res,n+1>>1);
int len = prework(n<<1);
tmp.resize(len); res.resize(len);
for(register int i=0; i<n; i++) tmp[i] = f[i];
ntt(tmp,len,1); ntt(res,len,1);
for(register int i=0; i<len; i++) res[i] = dec(inc(res[i],res[i]),1ll*tmp[i]*res[i]%MOD*res[i]%MOD);
ntt(res,len,-1);
tmp.resize(0); res.resize(n);
}
int a[MAXN];
int n, m;
poly A[MAXN<<2], B[MAXN<<2], invA, F;
inline int lc(int x){return x<<1;}
inline int rc(int x){return x<<1|1;}
void solve1(int x, int l, int r){
if(l==r) return A[x].resize(2), A[x][0] = 1, A[x][1] = dec(0,a[l]), void();
int mid = l+r >> 1;
solve1(lc(x),l,mid); solve1(rc(x),mid+1,r);
int num = A[lc(x)].size()+A[rc(x)].size()-2;
int len = prework(r-l+1);
f = A[lc(x)]; g = A[rc(x)];
f.resize(len); g.resize(len); A[x].resize(len);
ntt(f,len,1); ntt(g,len,1);
for(register int i=0; i<len; i++) A[x][i] = 1ll*f[i]*g[i]%MOD;
ntt(A[x],len,-1);
f.resize(0); g.resize(0); A[x].resize(r-l+2);
}
int ans[MAXN];
void solve2(int x, int l, int r){
if(l==r) return ans[l] = B[x][0], void();
int mid = l+r >> 1;
int len = prework(r-l+1<<1);
reverse(A[lc(x)].begin(),A[lc(x)].end());
reverse(A[rc(x)].begin(),A[rc(x)].end());
A[lc(x)].resize(len); A[rc(x)].resize(len); B[x].resize(len); B[lc(x)].resize(len); B[rc(x)].resize(len);
ntt(A[lc(x)],len,1); ntt(A[rc(x)],len,1), ntt(B[x],len,1);
for(register int i=0; i<len; i++) B[lc(x)][i] = 1ll*A[rc(x)][i]*B[x][i]%MOD, B[rc(x)][i] = 1ll*A[lc(x)][i]*B[x][i]%MOD;
ntt(B[lc(x)],len,-1); ntt(B[rc(x)],len,-1);
for(register int i=0; i<=mid-l+1; i++) B[lc(x)][i] = B[lc(x)][i+r-mid]; B[lc(x)].resize(mid-l+2);
for(register int i=0; i<=r-mid; i++) B[rc(x)][i] = B[rc(x)][i+mid-l+1]; B[rc(x)].resize(r-mid+1);
solve2(lc(x),l,mid); solve2(rc(x),mid+1,r);
}
int main(){
Read(n); Read(m);
F.resize(n+1);
for(register int i=0; i<=n; i++) Read(F[i]);
for(register int i=0; i<m; i++) Read(a[i]);
n = max(n,m-1);
solve1(1,0,n); poly_inv(A[1],invA,n+1);
reverse(invA.begin(),invA.end());
int len = prework(n<<1);
F.resize(len); invA.resize(len); B[1].resize(len);
ntt(F,len,1); ntt(invA,len,1);
for(register int i=0; i<len; i++) B[1][i] = 1ll*F[i]*invA[i]%MOD;
ntt(B[1],len,-1); for(register int i=0; i<=n; i++) B[1][i] = B[1][i+n]; B[1].resize(n+1);
solve2(1,0,n);
for(register int i=0; i<m; i++) printf("%d\n",ans[i]);
return 0;
}
```**