转自 z55250825 的几篇关于FFT的博文(二)
题目大意:高精度乘法。
fft的实现貌似有很多种,咱先写的是一种递归的fft,应该算是比较快的了吧。参考了 Evil君 的代码,那个运算符重载看的咱P党泪流满面。 (没想到P竟然有运算符重载咩...)
先背模板再理解0.0
以下是待补的对模板的理解
{
其实讲的主要的关键就是如何递归,他记录了一个深度 t,一个左边界s(开区间的),和一个最后FFT的结果的数组a。
他实际上是在递归的过程中就已经计算好了叶子的了,所以复杂度是O(nlogn),咱们来看看咱们如何通过这些递归变量计算出咱们需要的需要计算的(实际上弄出这个来这个算法就基本上可以了)。
主要的代码段就是这里:
{
for i:=0 to n>>(t+1)-1 do
begin
p:=i<<(t+1)+s;
wt:=w[i<<t]*a[p+1<<t];
tt[i]:=a[p]+wt;
tt[i+n>>(t+1)]:=a[p]-wt;
end;
for i:=0 to n>>t-1 do a[i<<t+s]:=tt[i];
回顾之前的FFT算法,咱们在深度为K的时候,将A分成两半,实际上可以看做是将A按照二进制的第K位是否为1分成两半的。以这个咱们发现..咱们记录的那个S实际上就是当前进行FFT所包含的A的元素的下标二进制的公共部分....然后当前深度的合并的元素实际上就在这里面,且就是在S的二进制基础上在最高位加上一个二进制,即0~ n shr t,即S+[0,n shr t-1]shl (t+1)(这个实际上就是代码中的p了 0v0)
咱们攻克了第二行,可是第一行的那个循环还是有点不对劲的感觉...按道理应该是 0~n shr t-1的,可是这里是0~n shr t shr 1-1,实际上这就是要用蝴蝶操作了,所以只需要枚举一半的量,也就是说,这里实际上要确定两个进行蝴蝶操作的A的下标的二进制关系。
首先确定一点,咱们在整个算法中用到a的地方仅仅在于划分,所以咱们用它们的对应的位置存储y[],然后之后都是直接用 Y0+xY1这样的形式来计算,所以咱们要清楚循环里面存的实际上应该是y。当前所要计算的应该是 w(n>>t,)系列的,咱们将它们存储到对应的区间里面。(其实就是原来的!!!)
对于某一个区间给定的 s,t,咱们首先可以计算出这个里面涉及到的 a[],为[s,s+n>>t],然后分段,一段就是 i+1<<(t+1),其对应的应该就是+1<<t.
}
吐槽:这个模板还需要优化额..为什么感觉和普通的高精度比还是不太行= =
(不过去wikioi的那个10^5的测试貌似跑的还是蛮快的...但是目测常数略大...,因为10^5的数要800ms+跑完)
========================
program fft;
type cp=record x,y:double;end;
arr=array[0..1 shl 14]of cp;
var a,b,tt,w:array[0..1 shl 14]of cp;
c:array[0..1000010]of longint;
n,tot1,tot2:longint;
operator *(var a,b:cp)c:cp;
begin c.x:=a.x*b.x-a.y*b.y;c.y:=a.x*b.y+a.y*b.x;end;
operator +(var a,b:cp)c:cp;
begin c.x:=a.x+b.x;c.y:=a.y+b.y;end;
operator -(var a,b:cp)c:cp;
begin c.x:=a.x-b.x;c.y:=a.y-b.y;end;
procedure dft(var a:arr;s,t:longint);
var i,p:longint;
wt:cp;
begin
if n>>t=1 then exit;
dft(a,s,t+1);
dft(a,s+1<<t,t+1);
for i:=0 to n>>(t+1)-1 do
begin
p:=i<<(t+1)+s;
wt:=w[i<<t]*a[p+1<<t];
tt[i]:=a[p]+wt;
tt[i+n>>(t+1)]:=a[p]-wt;
end;
for i:=0 to n>>t-1 do a[i<<t+s]:=tt[i];
end;
procedure init;
var ch:char;
i,k:longint;
j:cp;
begin
read(ch);tot1:=0;tot2:=0;
while (ord(ch)>=ord('0'))and(ord(ch)<=ord('9')) do
begin
a[tot1].x:=ord(ch)-ord('0');
read(ch);
inc(tot1);
end;
read(ch);
while (ord('0')<=ord(ch)) and(ord(ch)<=ord('9')) do
begin
b[tot2].x:=ord(ch)-ord('0');
read(ch);
inc(tot2);
end;
dec(tot1);dec(tot2);
for i:=0 to tot1 shr 1 do
begin
j:=a[i];a[i]:=a[tot1-i];a[tot1-i]:=j;
end;
for i:=0 to tot2 shr 1 do
begin
j:=b[i];b[i]:=b[tot2-i];b[tot2-i]:=j;
end;
if tot1<tot2 then tot1:=tot2;
n:=1;
while n>>1<(tot1+1) do n:=n shl 1;
for i:=0 to n-1 do w[i].x:=cos(pi*2*i/n);
for i:=0 to n-1 do w[i].y:=sin(pi*2*i/n);
dft(a,0,0);dft(b,0,0);
for i:=0 to n-1 do w[i].y:=-w[i].y;
for i:=0 to n-1 do a[i]:=a[i]*b[i];
dft(a,0,0);
fillchar(c,sizeof(c),0);
for i:=0 to n-1 do
begin
c[i]:=c[i]+round(a[i].x/n);
c[i+1]:=c[i] div 10;
c[i]:=c[i] mod 10;
end;
i:=n;
while (c[i]=0)and(i>0) do dec(i);
for k:=i downto 0 do write(c[k]);
end;
begin
init;
end.