【洛谷P4245】【模板】任意模数多项式乘法

题目

题目链接:https://www.luogu.com.cn/problem/P4245
给定 \(2\) 个多项式 \(F(x), G(x)\) ,请求出 \(F(x) * G(x)\)
系数对 \(p\) 取模,且不保证 \(p\) 可以分解成 \(p = a \cdot 2^k + 1\) 之形式。
\(n\leq 10^5;a,b\leq 10^9;p\leq 10^9+9\)

思路

不保证 \(p\) 是 NTT 模数,所以不能直接用 NTT 做。
一般有两种处理方法,一是用三模数 NTT,最后 CRT 合并,精度可达 \(10^{26}\)。另一种是拆位 FFT。也称 MTT。
拆位 FFT 思路很简单,将多项式每一位系数拆成两部分 \(a_1,a_2\),其中 \(a_1\)\(a\) 在二进制下前十五位,\(a_2\)\(a\) 在二进制下后十五位。
那么

\[(F*G)[i]=(2^{15}a_1[i]+a_2[i])\times (2^{15}b_1[i]+b_2[i]) \]

\[=2^{30}a_1[i]a_2[i]+2^{15}(a_1[i]b_2[i]+a_2[i]b_1[i])+a_2[i]b_2[i] \]

这样一顿乱搞之后,卷积后每一项系数不超过 \(n\times (p^{\frac{1}{2}})^2=np\leq 10^{14}\)。可以用 long double 强行存下。
这样的话只需要 \(4\) 次 DFT 和 \(3\) 次 IDFT 即可。
但是依然有更优的办法。我们设复多项式 \(F'[i]=a_1[i]+a_2[i]i\)\(G'[i]=a_1[i]-a_2[i]i\)\(H'[i]=b_1[i]+b_2[i]i\),那么

\[(F'*H')[i]=a_1[i]b_1[i]-a_2[i]b_2[i]+i(a_1[i]b_2[i]+a_2[i]b_1[i]) \]

\[(G'*H')[i]=a_1[i]b_1[i]+a_2[i]b_2[i]+i(a_1[i]b_2[i]-a_2[i]b_1[i]) \]

那么

\[(F'*H')[i]+(G'*H')[i]=2a_1[i]b_1[i]+2a_1[i]b_2[i]i \]

\[(F'*H')[i]-(G'*H')[i]=-2a_2[i]b_2[i]+2a_2[i]b_1[i]i \]

我们只需要 \(3\) 次 DFT 和 \(2\) 次 IDFT 即可。
注意为了保证精度,最好预处理单位根。
时间复杂度 \(O(n\log n)\)
事实上我写的 5 次 FFT 还没有 QuantAsk 7 次的做法快 /kk。

代码

#include <bits/stdc++.h>
#define cp complex<long double>
using namespace std;
typedef long long ll;
typedef long double ld;

const int N=300010;
const ld pi=acos(-1);
int n,m,lim,MOD,rev[N];
cp f[N],g[N],h[N],w[N];

void FFT(cp *f,int tag)
{
	for (int i=0;i<lim;i++)
		if (rev[i]<i) swap(f[rev[i]],f[i]);
	w[0]=cp(1,0);
	for (int k=1;k<lim;k<<=1)
	{
		cp tmp(cos(pi/k),tag*sin(pi/k));
		for (int i=k-2;i>=0;i-=2)
			w[i]=w[i>>1],w[i+1]=w[i]*tmp;
		for (int i=0;i<lim;i+=(k<<1))
		{
			for (int j=0;j<k;j++)
			{
				cp x=f[i+j],y=w[j]*f[i+j+k];
				f[i+j]=x+y; f[i+j+k]=x-y;
			}
		}
	}
}

int main()
{
	scanf("%d%d%d",&n,&m,&MOD);
	for (int i=0,x;i<=n;i++)
	{
		scanf("%d",&x);
		f[i]=cp(x>>15,x&32767);
		g[i]=cp(x>>15,-(x&32767));
	}
	for (int i=0,x;i<=m;i++)
	{
		scanf("%d",&x);
		h[i]=cp(x>>15,x&32767);
	}
	lim=1;
	while (lim<=n+m) lim<<=1;
	for (int i=0;i<lim;i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)?(lim>>1):0);
	FFT(f,1); FFT(g,1); FFT(h,1);
	for (int i=0;i<lim;i++)
		f[i]*=h[i],g[i]*=h[i];
	FFT(f,-1); FFT(g,-1);
	for (int i=0;i<=n+m;i++)
	{
		ll a=(ll)((f[i]+g[i]).real()/2.0/lim+0.4999)%MOD;
		ll b=(ll)((f[i]+g[i]).imag()/2.0/lim+0.4999)%MOD;
		ll c=(ll)((f[i]-g[i]).imag()/2.0/lim+0.4999)%MOD;
		ll d=(ll)((g[i]-f[i]).real()/2.0/lim+0.4999)%MOD;
		printf("%lld ",(((1LL<<30)*a+(1LL<<15)*(b+c)+d)%MOD+MOD)%MOD);
	}
	return 0;
}
posted @ 2021-01-18 18:27  stoorz  阅读(63)  评论(0编辑  收藏  举报