MTT学习小记

求p是1e9级别,n是1e5级别的fft

首先拆系数拆成aw+b的形式,那么求的是(aw+b)(cw+d)=acw^2+(ad+bc)w+bd,变成求ac,ad,bc,bd的卷积

构造\(P=(a+bi)(c+di)=(ac-bd)+(ad+bc)i\)\(Q=(a-bi)(c+di)=(ac+bd)+(ad-bc)i\),求出PQ之后解方程可以解出来

观察a+bi和a-bi是共轭的,根据共轭复数的性质(ab)'=a'b',求出a+bi的点值之后可以直接得到a-bi的点值

具体来说,\([x^i]DFT_{a+bi}(x)=[x^{N-i}]DFT_{a-bi}(x)\)(注意是N不是N-1,因为1和N-1共轭)

其实就是i和N-i的单位根也共轭

(a+bi和a-bi共轭,a-bi的wj与a+bi的w(n-j)共轭,所以a-bi的点值j和a+bi的点值n-j共轭,不是相等)

UPD:可以画图理解,初始向量共轭(对称),wj,wn-j每次转圈方向相反最后结果也共轭

c+di直接求,再对PQ用两次IDFT即可共四次DFT求出最终解

注意精度,所以单位根不能一个个乘过去

code

洛谷模板

#include <bits/stdc++.h>
#define fo(a,b,c) for (a=b; a<=c; a++)
#define fd(a,b,c) for (a=b; a>=c; a--)
#define ll long long
//#define file
using namespace std;

struct type{long double x,y;} p[262144],q[262144],F[262144];
type operator + (type a,type b) {return {a.x+b.x,a.y+b.y};}
type operator - (type a,type b) {return {a.x-b.x,a.y-b.y};}
type operator * (type a,type b) {return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};}
int N,len,n,m,mod,i,j,k,l,a[262144],b[262144],c[262144],d[262144];
ll ac[200001],ad[200001],bc[200001],bd[200001],ans[200001];
const ll w=100000;

void dft(type *a,int tp)
{
	static type A[262144];
	int i,j,k,l,S=N,s1=2,s2=1;
	type u,v,w,W;
	
	fo(i,0,N-1)
	{
		j=i;k=0;
		fo(l,1,len)
		k=k*2+(j&1),j>>=1;
		A[i]=a[k];
	}
	memcpy(a,A,sizeof(A));
	
	fo(i,1,len)
	{
		S>>=1;
		fo(j,0,S-1)
		{
			fo(k,0,s2-1)
			{
				W={cos(2*M_PI*k/s1),sin(2*M_PI*k/s1)*tp};
				u=a[j*s1+k],v=a[j*s1+k+s2]*W;
				a[j*s1+k]=u+v;
				a[j*s1+k+s2]=u-v;
			}
		}
		s1<<=1,s2<<=1;
	}
}

int main()
{
	#ifdef file
	freopen("mtt.in","r",stdin);
	freopen("mtt.out","w",stdout);
	#endif
	
	scanf("%d%d%d",&n,&m,&mod);len=ceil(log2(n+m+1));N=pow(2,len);
	fo(i,0,n) scanf("%d",&j),a[i]=j/w,b[i]=j%w;
	fo(i,0,m) scanf("%d",&j),c[i]=j/w,d[i]=j%w;
	
	fo(i,0,n) p[i]={a[i],b[i]};
	dft(p,1);
	fo(i,0,N-1) q[i]={p[(N-i)%N].x,-p[(N-i)%N].y};
	fo(i,0,m) F[i]={c[i],d[i]};
	dft(F,1);
	
	fo(i,0,N-1) p[i]=p[i]*F[i],q[i]=q[i]*F[i];
	dft(p,-1),dft(q,-1);
	fo(i,0,N-1) p[i].x/=N,p[i].y/=N,q[i].x/=N,q[i].y/=N;
	fo(i,0,n+m) ac[i]=floor((p[i].x+q[i].x)/2+0.5),bd[i]=floor((q[i].x-p[i].x)/2+0.5),ad[i]=floor((p[i].y+q[i].y)/2+0.5),bc[i]=floor((p[i].y-q[i].y)/2+0.5);
	fo(i,0,n+m) ans[i]=(ac[i]%mod*w%mod*w+((ad[i]+bc[i])%mod)*w+bd[i])%mod;
	
	fo(i,0,n+m) printf("%lld ",(ans[i]+mod)%mod);printf("\n");
}
posted @ 2020-08-06 19:32  gmh77  阅读(171)  评论(0编辑  收藏  举报