[模板]生成函数与NTT、FFT

一、前置知识

二、生成函数

组合计数中一般定义了一类组合对象A,每个组合对象有不同的大小(即xn),要求每个大小的组合对象的数量An

1、无标号计数:普通型生成函数OGF(Ordinary Generating Function)

生成函数A(x)=∑Aixi (i>=0)

一般A表示该类组合对象,xi表示大小为i的组合对象,Ai表示大小为i的组合对象的数量。

设有两类无标号组合对象A、B,若一类新的组合对象D满足组合对象xDn=xAi×xBj (i+j=n),即D中的大小为n的组合对象 可由A和B中的组合对象拼接而成 

则Di=∑i+j=n AiBj ,DxDn=∑i+j=n Ai xAi  ×  BxBj  

该式为卷积形式,可以用FFT或NTT加速。

实际上普通型生成函数中的xi不一定表示大小为i的组合对象,i也可代表序号,如斐波那契数列的生成函数:

例题:

2、有标号计数:指数型生成函数EGF (Exponential Generating Function)

生成函数A(x)=∑Axi/i! (i>=0)

一般A表示该类组合对象,xi/i! 表示大小为i的组合对象,Ai表示大小为i的组合对象的数量。

有标号的组合对象在拼接时与无标号对象不同。

设原本A中有某个大小为n组合对象xn,组合对象内的对象对应不同标号1~n。B中有某个大小为m的组合对象xm,组合对象内的对象对应不同标号1~m。

有标号组合对象拼接时需要重新为大小为n+m的组合对象xn+m分配标号,且保持原有标号顺序不变,从n+m个对象中选n个对象使其原始标号为1~n,另外m个对象的原始标号为1~m。

因此需要乘以系数C(n+m,n) ,故Dn =∑i+j=n Ai ×  Bj  × C(n,i)

即 DxDn/n! =∑i+j=n Ai xAi /i!  ×  BxBj /j!

该式为卷积形式,可用NTT或FTT加速。

例题:

其中xn/n!代表 n个格子的组合对象,系数代表n个格子的方案数。一开始只有一种方案,就是n个格子全填白/蓝/红(0个格子蓝色/奇数格子红色 的方案数设为0)。

1、比如在对蓝色生成函数和白色生成函数做乘积的时候,设i+j=n 0<=i<=n

则相当于先枚举i,j,然后从n个格子中确定i个格子作为白色,另外j个格子作为蓝色,不同格子视为不同标号,因此为有标号计数。

2、蓝白格子 再和红色格子做乘积的时候,设i+j=n 0<=i<=n

则相当于先枚举i,j,然后从n个格子中确定i个格子作为蓝色白色,另外j个格子作为红色,这个再乘上这i个格子为蓝/白的方案,再乘上这j个格子为红的方案。

注意到在做D(x)=A(x)*B(x) 时A和B的方案应该完全不同,如A为蓝/白,B为红。

3、例题

这题可将骨牌大小之和 认为是组合对象的大小,那么答案就是Dn

问题在于初始的生成函数怎么设?

思路一:令Ai(x)= aixi+ ai2x2i +ai3x3i+...表示 只用类型i的骨牌的方案数

发现在合并两个类型骨牌Ai,Aj的时候如果直接做系数卷积(OGF),那么相当于确定在i后面放j类型骨牌,j类型骨牌必定放在一起,答案偏小。

如果做系数卷积(EGF),由于一个骨牌的大小为1*i,不可拆分,做EGF是C(n+m,n)会认为格子可拆分,会使答案偏大。

思路二:令A(x)=a1x1+a2x2+...+anxn 表示只用1个骨牌的方案数

那么Ak(x)就是使用k个骨牌的方案数。可以认为在做系数卷积的时候直接在k-1个骨牌后面放1个骨牌,故为OGF。

枚举所用的骨牌数量k,把Ak(x)在xn的系数累加起来,发现答案为1/(1-A(x))在xn的系数

好像要用到多项式求逆,不会

4、trick

三、快速傅里叶变换FFT(Fast Fourier Transformation)

参考资料:

https://blog.csdn.net/Flag_z/article/details/99163939
https://www.cnblogs.com/pks-t/p/9251147.html 
https://www.zhihu.com/question/22298352

快速傅里叶变换原理:
复数乘法:模长相乘幅角相加 根据欧拉公式ke=kcos θ+iksinθ
变换前提:n=2k; 如果n<2k,那么就往后补系数0 
A(x)=a0+a2x2+a4x4+...+an-2xn-2+x*(a1+a3x2+a5x4+...+an-1xn-1)
      =A0(x2)+x*A1(x2) 将系数分为奇数位和偶数位 
wnk 单位根的k次幂 作为点值表示法的点
A(wnk)=A0(wn2k)+wnk*A1(wn2k)
           =A0(wn/2k)+wnk*A1(wn/2k) 这个性质有利于变成子问题 
A(wnk+n/2)=A0(wn2k+n)+wnk+n/2*A1(wn2k+n)
                  =A0(wn/2k)-wnk*A1(wn/2k) 后面一半的单位根的解和前面一半的形式类似 
由n/2的子问题可以得到n的答案,故一直/2,一共log层
上述算法的关键:单位根的优秀性质 

考虑系数的排列
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 
0 2 4 6 8 10 12 14|1 3 5 7 9 11 13 15 :                                   ...0                                      |                                    ...1 
0 4 8 12|2 6 10 14|1 5 9 13|3 7 11 15 :               ...00                |               ...10                 |              ...01                 |               ...11 
0 8|4 12|2 10|6 14|1 9|5 13|3 11|7 15 :    ...000     |    ...100     |    ...010      |   ...110      |    ...001     |     ...101    |     ...011     |     ...111
0|8|4|12|2|10|6|14|1|9|5|13|3|11|7|15 :0000|1000|0100|1100|0010|1010|0110|1110|0001|1001|0101|1101|0011|1011|0111|1111

可以发现二进制位最后一位优先级最高,从后向前优先级逐渐下降,换句话说,一行里面,后面一个数的二进制翻转都是前面一个数的二进制翻转+1。

即每个位置实际上都要替换为这个位置的二进制翻转对应的系数,由于二进制翻转是对称的,故直接swap()即可。
那么,位置x的系数对应的新的位置为rev[x] 然后就可以从下往上推了
注:A[len>>1][i]表示A(wleni%len) 每一层的函数A都是由下一层的两个函数A合成的,且第一维的len可以省略从而节省空间开销 
以上操作为把一个函数从系数表示法变成点表示法,从而使得乘法加速到O(n)

接下来考虑把函数从点表示法变成系数表示法 即IFFT 逆 快速 傅里叶变换 
已知点表示法为(wn0,A(wn0)),(wn1,A(wn1)),...,(wnn-1,A(wnn-1))
把A(wn0),A(wn1),...,A(wnn-1)当成新的多项式的系数,则B(x)=A(wn0)+A(wn1)*x+...
再考虑转化为点表示法:(注:此时单位根采用w(-1,n)) 

B(wn-k)=A(wn0)+A(wn1)*wn-k+A(wn2)*wn-2k+...+A(wnn-1)*wn-nk

           =∑0≤i≤n-1(       A(wni)         *       (wn-k)i     

           =∑0≤i≤n-1(  ∑0≤j≤n-1( aj*(wni)j )   * (wn-k))

           =∑0≤i≤n-1(  ∑0≤j≤n-1( aj*(wni)j * (wn-k)i    ) ) 

      =∑0≤j≤n-1(  ∑0≤i≤n-1( aj*(wni)j * (wn-k)i    ) )

           =∑0≤j≤n-1(     aj*   ∑0≤i≤n-1(wnj-k)i            )

若 j-k=0 则0≤i≤n-1(wnj-k)i=n ;若j-k!=0 则0≤i≤n-1(wnj-k)i= 0(利用等比数列求和发现分子=1- wn(j-k)n=0)
即:B(wn-k)=n*a

卷积: f*g(n)=(f(i)*g(n-i))
意义:把g以0为中心垂直翻转再向右平移n得到g',再和 f对应项相乘再相加 即g'(i)=g(n-i)
如n=0的情形 
f :                         0 1 2 3 4 5 6... 
g' :... 6 5 4 3 2 1 0
如n=2的情形 
f :                  0 1 2 3 4 5 6... 
g' :... 6 5 4 3 2 1 0 
容易发现 上下对应的两个值他们的和为n 

A*B Problem FFT版
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define Ford(i,a,b) for(int i=a;i>=b;i--)
#define For(i,a,b) for(int i=a;i<=b;i++)
const int N=2e5+1000;
const double Pi=acos(-1.0);
struct Complex
{
	double real,imag;
	Complex operator +(const Complex& x) {return (Complex){real+x.real,imag+x.imag};}
	Complex operator -(const Complex& x) {return (Complex){real-x.real,imag-x.imag};}
	Complex operator *(const Complex& x) {return (Complex){real*x.real-imag*x.imag,real*x.imag+imag*x.real};}
};
int rev[N],n;
void FFT(Complex *A,int flag)
{
	For(i,0,n-1) if (i<rev[i]) std::swap(A[i],A[rev[i]]);
	for(int w=1;w<n;w<<=1)
	{
		Complex Wn=(Complex){cos(Pi/w),sin(Pi/w)*flag};
		for(int l=0;l<n;l+=(w<<1))
		{
			Complex W=(Complex){1,0};
			for(int k=0;k<w;k++,W=W*Wn) 
			{
				Complex x=A[l+k],y=W*A[l+w+k];
				A[l+k]=x+y;
				A[l+k+w]=x-y;
			}
		}
	}
	if (flag==-1) For(i,0,n-1) A[i].real=(int)(A[i].real/n+0.5);
}
char s[N];int len,m;
Complex A[N],B[N];
int main()
{
	scanf("%d",&len);m+=2*(len-1);
	scanf("%s",s);
	For(i,0,len-1) A[i].real=s[len-i-1]-'0';
	scanf("%s",s);
	For(i,0,len-1) B[i].real=s[len-i-1]-'0';
	n=1;while (n<m+1) n<<=1;
	For(i,0,n-1) rev[i]=(rev[i>>1]>>1)+((i&1)?(n>>1):0);//先翻i的前面一部分,再在最前面加上i的末位
    //i和rev[i]互为二进制前后翻转
	FFT(A,1);FFT(B,1);
	For(i,0,n-1) A[i]=A[i]*B[i];
	FFT(A,-1);
	For(i,0,n-1) A[i+1].real+=(int)A[i].real/10,A[i].real-=(int)A[i].real/10*10;
	int now=m+1;while (A[now].real==0) now--;
	Ford(i,now,0) printf("%d",(int)A[i].real);putchar('\n');
	return 0;
}

四、快速数论变换NTT(Number Theoretic Transform)

FFT利用单位根w的性质实现分治优化多项式乘法,原根也有这用的性质。

NTT即用原根代替FFT中的w,从而实现不用复数避免精度损失,同时实现取模。

若a,p互素,且p>1,对于an≡1(mod p)最小的n,我们称之为a模p的阶,记做δp(a)=n (a0≡an≡a2n≡...≡1)

设p是正整数,a是整数,若δp(a)等于ϕ(p),则称a为模p的一个原根 (ϕ(p)为p的欧拉函数)

即原根满足aϕ(p)≡1(mod p)。如果p是一个质数,则aϕ(p)=ap-1≡1(mod p)

另外原根还满足gi mod p (0≤i≤p-1)互不相同

常见模数与原根:998244353、1004535809、469762049的原根都为3

对于质数p=qn+1(n=2m),原根g满足gp-1=gqn≡1(mod p)

令gn=g(p-1)/n=gq则gn满足

①gnn=(g(p-1)/n)n=gp-1≡1(mod p)

②gnn/2=(g(p-1)/n)n/2=g(p-1)/2

由于gp-1≡1(mod p),故g(p-1)/2只有两种取值:1和-1

由于原根满足gi mod p (0≤i≤p-1)互不相同,故g (p-1)/2和g (p-1)必须不同

因此g(p-1)/2只能取-1,即gnn/2=-1

③g2n2k=gp-1/2n*2k =gp-1/n*k=gnk

④gnk+n/2=gnk×gnn/2=-gnk

故gn可代替FFT中的单位根wn进行类似推导:

A(gnk)=A0(gn2k)+gnk*A1(gn2k)
           =A0(gn/2k)+gnk*A1(gn/2k) 这个性质有利于变成子问题 
A(gnk+n/2)=A0(gn2k+n)+gnk+n/2*A1(gn2k+n)
                  =A0(gn/2k)-gnk*A1(gn/2k) 后面一半的单位根的解和前面一半的形式类似 

A*B problem NTT
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define Frd(i,a,b) for(int i=a;i>=b;i--)
#define For(i,a,b) for(int i=a;i<=b;i++)
const int N=4e6+1000;
const long long mod=998244353;
const long long g=3;
int rev[N],n;
long long power(long long x,long long y)
{
	x%=mod;
	long long ans=1;
	while (y)
	{
		if (y&1) ans=ans*x%mod;
		x=x*x%mod;y>>=1;
	}
	return ans; 
}
void NTT(long long *A,int flag)
{
	For(i,0,n-1) if (i<rev[i]) std::swap(A[i],A[rev[i]]);
	For(i,0,n-1) A[i]%=mod;
	for(int w=1;w<n;w<<=1)
	{
		long long gn=power(g,flag==1? (mod-1)/(w<<1) :mod-1- (mod-1)/(w<<1));//g(1,n)=g^(p-1/n) ,g(-1,n)=g(n-1,n)=g(1,n)^(n-1)
		for(int l=0;l<n;l+=(w<<1))
		{
			long long gnk=1;
			for(int k=0;k<w;k++,gnk=gnk*gn%mod) 
			{
				long long x=A[l+k],y=gnk*A[l+w+k]%mod;
				A[l+k]=x+y<mod?x+y:x+y-mod;
                A[l+k+w]=x-y>=0?x-y:x-y+mod;
			}
		}
	}
	if (flag==-1) //accumulate
    {
        long long inv=power(n,mod-2);
        For(i,0,n-1) A[i]=A[i]*inv%mod;
    }
}
char s[N];int len,m;
long long A[N],B[N];
int main()
{
	scanf("%s",s);len=strlen(s);m+=len-1;
	For(i,0,len-1) A[i]=s[len-i-1]-'0';
	scanf("%s",s);len=strlen(s);m+=len-1;
	For(i,0,len-1) B[i]=s[len-i-1]-'0';
	n=1;while (n-1<m) n<<=1;//最高位是10^(len1-1+len2-1),n-1至少要大于等于最高位
	For(i,0,n-1) rev[i]=(rev[i>>1]>>1)+((i&1)?(n>>1):0);//先翻i的前面一部分,再在最前面加上i的末位
    //i和rev[i]互为二进制前后翻转
	NTT(A,1);NTT(B,1);
	For(i,0,n-1) A[i]=A[i]*B[i]%mod;
	NTT(A,-1);
	For(i,0,n-1) A[i+1]+=A[i]/10,A[i]%=10;
	int now=n-1;while (A[now+1]) now++,A[now+1]+=A[now]/10,A[now]%=10;//由于最高位可能会大于9即进位,如9*9=81,故长度可能>n-1 
	while (A[now]==0) now--;
	Frd(i,now,0) printf("%lld",A[i]);putchar('\n');
	return 0;
}

注意到rev[]、gnk都可以预处理,另外如果有很多小多项式要合并成大多项式用vector会比较慢,不如用malloc自己写一个结构体

另外当flag=-1时把power(n,mod-2)记录下来也可以优化常数

不过这个写法为了缩小常数,拷贝构造函数等不进行memcpy,而是直接复制指针。

因此每次a*b都会多产生一个malloc,要记得free掉

A*B problem 小常数NTT
//used for merge a lot of poly,which size is small 
//by the way,this program try the best to accumulate
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#define For(i,a,b) for(int i=a;i<=b;i++)
#define Frd(i,a,b) for(int i=a;i>=b;i--)
namespace Poly
{
	const int M=(1<<22)+1000;//1e6+1e6=2e6, 
	const long long mod=998244353;
	const long long g=3;
	int now_logn=-1,now_n=0,now_rev[M];long long now_gnk[2][M];//buffer ,in order to accumulate 	
	long long power(long long x,long long y)
	{
		x%=mod;
		long long ans=1;
		while (y)
		{
			if (y&1) ans=ans*x%mod;
			x=x*x%mod;y>>=1;
		}
		return ans; 
	}
	void Init(int logn)
	{
		if (now_logn==logn) return ;
		now_logn=logn;now_n=(1<<now_logn);
		For(i,0,now_n-1) now_rev[i]=(now_rev[i>>1]>>1)+(i&1?now_n>>1:0);//reverse the bit in front,then reverse the bit at last
		for(int w=1;w<now_n;w<<=1)//1   10       100             1000
		{
			long long now_gn[2],temp_gnk[2];
			now_gn[0]=power(g,(mod-1)/(w<<1));//g(1,n)=g^(p-1/n)
			now_gn[1]=power(g,mod-1- (mod-1)/(w<<1));//g(-1,n)=g(n-1,n)=g(1,n)^(n-1)
			temp_gnk[0]=1;temp_gnk[1]=1;
			for(int k=0;k<w;k++)//1   10,11   100,101,110,111  1000,1001,1010,1011,1100,1101,1110,1111 
			{
				now_gnk[0][w+k]=temp_gnk[0];
				now_gnk[1][w+k]=temp_gnk[1];
				temp_gnk[0]=temp_gnk[0]*now_gn[0]%mod;
				temp_gnk[1]=temp_gnk[1]*now_gn[1]%mod;
			}
		}
	}
	struct poly
	{
		long long * A;//use "long long*" instead of "vector<long long>" ,in order to accumulate 
		int n,logn; 
		poly(int len,long long *B=0)//x^0 +... +x^n-1
		{
			n=1;logn=0;while (n<len) n<<=1,logn++;
			A=(long long*) malloc(sizeof(long long)*n);
			if (B!=0) 
			{
				For(i,0,len-1) 
				{
					A[i]=B[i]%mod;
					if (A[i]<0) A[i]+=mod;
				}
				if (len<n) memset(A+len,0,sizeof(long long)*(n-len));
			}
			else memset(A,0,sizeof(long long)*n);
		}
		poly(const poly&other)
		{
			n=other.n;
			logn=other.logn;
//			memcpy(A,other.A,n);//memcpy need a lot of time 
			A=other.A;
		}
		~poly()
		{
//			free(A);// it's need to be free at outside,because memcpy need a lot of time 
		}
		void copy(int len=-1)
		{
			if (len==-1) //copy_to_new_momory
			{
				long long *temp=A;
				A=(long long*) malloc(sizeof(long long)*n);
				memcpy(A,temp,n);
			}
			else//copy_to_new_momory and be bigger
			{
				if (len<=n) return ;
				long long *tempA=A;
				int tempn=n;
				while (n<len) n<<=1,logn++;
				A=(long long*) malloc(sizeof (long long)*n);
				memcpy(A,tempA,sizeof(long long)*tempn);
				memset(A+tempn,0,sizeof(long long)*(n-tempn));
			}
		}
		void resize(int len)
		{
			if (len<=n) return ;
			long long *temp=A;
			copy(len);
			free(temp);
		}
		void NTT(int len,int flag)
		{
			resize(len);Init(logn); 
			For(i,0,n-1) if (i<now_rev[i]) std::swap(A[i],A[now_rev[i]]);
			For(i,0,n-1) A[i]%=mod;
			for(int w=1;w<n;w<<=1)
			{
//				long long gn=power(g,flag==1? (mod-1)/(w<<1) :mod-1- (mod-1)/(w<<1));//g(1,n)=g^(p-1/n) ,g(-1,n)=g(n-1,n)=g(1,n)^(n-1)
				for(int l=0;l<n;l+=(w<<1))
				{
					long long *gnk=flag==1?now_gnk[0]+w:now_gnk[1]+w;
//					long long gnk=1;
					for(int k=0;k<w;k++) 
					{
						long long x=A[l+k],y=gnk[k]*A[l+w+k]%mod;
						A[l+k]=x+y<mod?x+y:x+y-mod;
						A[l+k+w]=x-y>=0?x-y:x-y+mod;
					}
				}
			}
			if (flag==-1) 
			{
				long long inv=power(n,mod-2);//record this in order to accumulate 
				For(i,0,n-1) A[i]=A[i]*inv%mod;
			}
		}
		poly operator + (const poly &other) 
		{
			poly temp(std::max(n,other.n));
			For(i,0,temp.n-1) 
			{
				if (i<n) temp.A[i]+=A[i];
				if (i<other.n) temp.A[i]+=other.A[i];
				temp.A[i]%=mod;
			}
			return temp;
		}
		poly operator - (const poly &other) 
		{
			poly temp(std::max(n,other.n));
			For(i,0,temp.n-1) 
			{
				if (i<n) temp.A[i]+=A[i];
				if (i<other.n) temp.A[i]-=other.A[i];
				temp.A[i]%=mod;
				if (temp.A[i]<0) temp.A[i]+=mod;
			}
			return temp;
		}
		poly operator * (const poly &other) //c=a*b  //return new memory
		{
			poly temp1(*this);temp1.copy(n+other.n-1); //x^n-1 * x^m-1 =x^(n+m-2) 0~n+m-2 =n+m-1
			poly temp2(other);temp2.copy(n+other.n-1);
			temp1.NTT(n+other.n-1,1); temp2.NTT(n+other.n-1,1); 
			For(i,0,temp1.n-1) temp1.A[i]=temp1.A[i]*temp2.A[i]%mod;
			temp1.NTT(n+other.n-1,-1); free(temp2.A);
			return temp1;
		}
		poly& operator *= (const poly &other)//doesn't return new momory
		{
			this->copy(n+other.n-1);
			poly temp(other);temp.copy(n+other.n-1);
			For(i,0,n-1) A[i]=A[i]*temp.A[i]%mod;
			this->NTT(n+other.n-1,-1) ;free(temp.A);
			return *this;
		}
	};
} 
const int N=4e6+1000;
char s[N];int len=0;
long long temp[N];
int main()
{
	scanf("%s",s);len=strlen(s);
	For(i,0,len-1) temp[i]=s[len-1-i]-'0';
	Poly::poly a(len,temp);
	scanf("%s",s);len=strlen(s);
	For(i,0,len-1) temp[i]=s[len-1-i]-'0';
	Poly::poly b(len,temp);
	Poly::poly c=a*b;
	For(i,0,c.n-1) temp[i]=0;
	For(i,0,c.n-1) temp[i]+=c.A[i],temp[i+1]+=temp[i]/10,temp[i]%=10;
	int top=c.n-1;while (temp[top+1]) temp[top+1]+=temp[top]/10,temp[top]%=10,top++;
	while (!temp[top]&&top) top--;
	Frd(i,top,0) printf("%d",temp[i]);putchar('\n'); 
	return 0;
}

 

posted @   th-is  阅读(170)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· AI技术革命,工作效率10个最佳AI工具
点击右上角即可分享
微信分享提示