转置原理

基本原理

基础公式

\(V=E_1E_2\cdots E_n\to V^T=E_n^TE_{n-1}^T\cdots E_1^T\)

\(E\)是基础矩阵,分为

  • 让某一位\(i\)乘上\(k\),转置后效果一样
  • 让某一位\(i\)乘上\(k\)加到另一位\(j\)上,转置后是把\(j\)乘上\(k\)加到\(i\)

转置fft

\(A\times B=C\),即\(C_{i+j}+=A_i\cdot B_j\)

如果把\(A\)当作常数,\(B\)当作输入的\(f\),就是把\(j\)这一位乘上\(A_i\)再加到\(i+j\)这位上

那么转置后就是把\(i+j\)这一位乘上\(A_i\)加到\(j\)这一位上

\(B_j+=A_i\cdot C_{i+j}\),或者\(B_{i-j}+=C_i\cdot A_j\)

也就是说加法卷积转置一下就变成了一个减法卷积,记为\(C\times^TA=B\)

具体实现为翻转\(A\)(\(A_0=A_{len}\))做一次多项式乘法,然后\(B_i=B_{i+len_A}\)

不难发现(长度为\(n+m\)的多项式\()\times^T\)(长度为\(n\)的多项式)\(=\)(长度为\(m\)的多项式),而且没有交换律

具体作用

题目要求\(Vf\),直接求不好求,但是很容易想到如何求\(V^Tf\)。所以考虑求出\(f\to V^Tf\)这个过程在干什么,再把这个过程转置一下(就是把过程完全倒过来)就能得到\(f\to Vf\)的做法

实践

多项式多点求值

可以看作是\(\begin{bmatrix}1&a_0&a_0^2&\cdots\\1&a_1&a_1^2&\cdots\\1&a_2&a_2^2&\cdots\\\cdots\\1&a_{n-1}&a_{n-1}^2&\cdots\end{bmatrix}\times f\)(直接展开就是所求的式子)

设左边的为\(V\),考虑如何求\(V^Tf\)

即为\([x^k]\sum_{i=0}^{n-1}f_ia_i^k\),也就是求多项式\(\sum_{i=0}^{n-1}\frac{f_i}{1-a_ix}\)

这个式子用分治fft求出,具体过程是:

  • 当前在\((l,r)\)
  • 递归处理\((l,mid)\) \((mid+1,r)\)
  • 设左边为\(\frac{B_l}{A_l}\),右边为\(\frac{B_r}{A_r}\)
  • 那么当前区间分子为\(B_l\times A_r+B_r\times A_l\),分母为\({A_l\times A_r}\)

最后答案为\(ans=B_{[0,n-1]}\times A_{[0,n-1]}^{-1}\)


那么如何求\(f\to Vf\)呢,就是把过程完全反过来(注意\(A\times B=C\)倒过来就变成了\(C\times^T B=A\),假设\(A\)为变量,\(B\)为常数)

由于\(B\)是一个常数,或者说是一个与\(f\)无关的式子,所以反过来没有变化

先求出\(B_{[0,n-1]}=f\times^T A_{[0,n-1]}^{-1}\)(这里可以理解为往\(f\)后面添\(0\)使\(f\)长度为\(2n\)

\(A_{[0,n-1]}\)可以用一次分治fft求出

然后用分治fft:

  • 当前在\((l,r)\),当前分子为\(B\)
  • 求出左边的分子为\(B\times^T A_r\),右边的分子为\(B\times^T A_l\)
  • 递归处理\((l,mid)\) \((mid+1,r)\)

最后递归到叶节点\((i,i)\)时,\(ans_i=[x^0]B_{[i,i]}\)


仔细观察的话会发现,就是完全把之前的过程反了过来(包括顺序)

然后这种方法只用到了一次多项式求逆(最开始),和分治fft,显然常数比一般要用取膜的做法小,优化得好甚至可以5e5多点求值(膜$E)。

#include <bits/stdc++.h>
using namespace std;

template<typename T>
inline void Read(T &n){
	char ch; bool flag=false;
	while(!isdigit(ch=getchar()))if(ch=='-')flag=true;
	for(n=ch^48;isdigit(ch=getchar());n=(n<<1)+(n<<3)+(ch^48));
	if(flag)n=-n;
}

const int MAXN = 64005;
const int MOD = 998244353;
const int G = 3;

inline int inc(int a, int b){
	a += b;
	if(a>=MOD) a -= MOD;
	return a;
}

inline void iinc(int &a, int b){a = inc(a,b);}

inline int dec(int a, int b){
	a -= b;
	if(a<0) a += MOD;
	return a;
}

inline void ddec(int &a, int b){a = dec(a,b);}

inline int ksm(int base, int k=MOD-2){
	int res = 1;
	while(k){
		if(k&1)
			res = 1ll*res*base%MOD;
		base = 1ll*base*base%MOD;
		k >>= 1;
	}
	return res;
}

typedef vector<int> poly;

int tr[MAXN<<2], wn[MAXN<<2];
inline int prework(int n){
	int len = 1; while(len<=n) len<<=1;
	for(register int i=0; i<=len; i++) tr[i] = (tr[i>>1]>>1)|((i&1)?len>>1:0);
	wn[0] = 1; wn[1] = ksm(G,(MOD-1)/len); for(register int i=2; i<=len; i++) wn[i] = 1ll*wn[i-1]*wn[1]%MOD;
	return len;
}

poly f, g;
inline void ntt(poly &f, int n, int flag){
	for(register int i=0; i<n; i++) if(i<tr[i]) swap(f[i],f[tr[i]]);
	for(register int len=2; len<=n; len<<=1){
		int base = flag*n/len;
		for(register int l=0; l<n; l+=len){
			int now = flag==1?0:n;
			for(register int i=l; i<l+len/2; i++){
				int tmp = 1ll*f[i+len/2]*wn[now]%MOD;
				f[i+len/2] = dec(f[i],tmp);
				f[i] = inc(f[i],tmp);
				now += base;
			}
		}
	}
	if(flag==-1){
		int invn = ksm(n);
		for(register int i=0; i<n; i++) f[i] = 1ll*f[i]*invn%MOD;
	}
}

void poly_inv(poly f, poly &res, int n){
	static poly tmp;
	if(n==1) return res.resize(1), res[0] = ksm(f[0]), void();
	poly_inv(f,res,n+1>>1);
	int len = prework(n<<1);
	tmp.resize(len); res.resize(len);
	for(register int i=0; i<n; i++) tmp[i] = f[i];
	ntt(tmp,len,1); ntt(res,len,1);
	for(register int i=0; i<len; i++) res[i] = dec(inc(res[i],res[i]),1ll*tmp[i]*res[i]%MOD*res[i]%MOD);
	ntt(res,len,-1);
	tmp.resize(0); res.resize(n);
}

int a[MAXN];

int n, m;

poly A[MAXN<<2], B[MAXN<<2], invA, F;
inline int lc(int x){return x<<1;}
inline int rc(int x){return x<<1|1;}

void solve1(int x, int l, int r){
	if(l==r) return A[x].resize(2), A[x][0] = 1, A[x][1] = dec(0,a[l]), void();
	int mid = l+r >> 1;
	solve1(lc(x),l,mid); solve1(rc(x),mid+1,r);
	int num = A[lc(x)].size()+A[rc(x)].size()-2;
	int len = prework(r-l+1);
	f = A[lc(x)]; g = A[rc(x)];
	f.resize(len); g.resize(len); A[x].resize(len);
	ntt(f,len,1); ntt(g,len,1);
	for(register int i=0; i<len; i++) A[x][i] = 1ll*f[i]*g[i]%MOD;
	ntt(A[x],len,-1);
	f.resize(0); g.resize(0); A[x].resize(r-l+2);
}

int ans[MAXN];
void solve2(int x, int l, int r){
	if(l==r) return ans[l] = B[x][0], void();
	int mid = l+r >> 1;

	int len = prework(r-l+1<<1);
	reverse(A[lc(x)].begin(),A[lc(x)].end());
	reverse(A[rc(x)].begin(),A[rc(x)].end());
	A[lc(x)].resize(len); A[rc(x)].resize(len); B[x].resize(len); B[lc(x)].resize(len); B[rc(x)].resize(len);
	ntt(A[lc(x)],len,1); ntt(A[rc(x)],len,1), ntt(B[x],len,1);
	for(register int i=0; i<len; i++) B[lc(x)][i] = 1ll*A[rc(x)][i]*B[x][i]%MOD, B[rc(x)][i] = 1ll*A[lc(x)][i]*B[x][i]%MOD;
	ntt(B[lc(x)],len,-1); ntt(B[rc(x)],len,-1);
	for(register int i=0; i<=mid-l+1; i++) B[lc(x)][i] = B[lc(x)][i+r-mid]; B[lc(x)].resize(mid-l+2);
	for(register int i=0; i<=r-mid; i++) B[rc(x)][i] = B[rc(x)][i+mid-l+1]; B[rc(x)].resize(r-mid+1);
	solve2(lc(x),l,mid); solve2(rc(x),mid+1,r);
}

int main(){
	Read(n); Read(m);
	F.resize(n+1);
	for(register int i=0; i<=n; i++) Read(F[i]);
	for(register int i=0; i<m; i++) Read(a[i]);
	n = max(n,m-1);
	solve1(1,0,n); poly_inv(A[1],invA,n+1);
	reverse(invA.begin(),invA.end());
	int len = prework(n<<1);
	F.resize(len); invA.resize(len); B[1].resize(len);
	ntt(F,len,1); ntt(invA,len,1);
	for(register int i=0; i<len; i++) B[1][i] = 1ll*F[i]*invA[i]%MOD;
	ntt(B[1],len,-1); for(register int i=0; i<=n; i++) B[1][i] = B[1][i+n]; B[1].resize(n+1);
	solve2(1,0,n);
	for(register int i=0; i<m; i++) printf("%d\n",ans[i]);
	return 0;
}
```**
posted @ 2021-06-26 14:20  oisdoaiu  阅读(724)  评论(0编辑  收藏  举报