FFT练习

FFT板子 

const double pi = acos(-1);
const int N = 1e6+10;
// Minimal complex number: x is the real part, y the imaginary part.
// A and B are the global work buffers used by mul.
struct _ {
	double x, y;
} A[N], B[N];
// Complex product (a.x + i a.y)(b.x + i b.y).
_ operator * (_ a, _ b) {
	_ r;
	r.x = a.x*b.x - a.y*b.y;
	r.y = a.x*b.y + a.y*b.x;
	return r;
}
// Component-wise complex sum.
_ operator + (_ a, _ b) {
	_ r;
	r.x = a.x + b.x;
	r.y = a.y + b.y;
	return r;
}
// Component-wise complex difference.
_ operator - (_ a, _ b) {
	_ r;
	r.x = a.x - b.x;
	r.y = a.y - b.y;
	return r;
}

// n, m, x, s are not used by the routines in this snippet (presumably problem inputs).
// lim = 2^l is the transform length chosen by init; R is its bit-reversal table.
int n, m, x, s;
int lim, l;
int a[N], R[N];
// 64-bit coefficient storage (presumably for mul's ll* arguments).
ll f[N];

// Pick the transform length lim = 2^l, the smallest power of two with lim > n,
// and build the bit-reversal permutation R of [0, lim) used by FFT.
void init(int n) {
	for (lim=1,l=0; lim<=n; lim<<=1,++l) ;
	// For n == 0 we get l == 0, and (i&1)<<(l-1) would shift by -1 (undefined behaviour).
	REP(i,0,lim-1) R[i]=(R[i>>1]>>1)|(l ? (i&1)<<(l-1) : 0);
}

// In-place iterative FFT of length lim (bit-reversal table R from init).
// tp = 1: forward transform; tp = -1: inverse transform INCLUDING the 1/lim scaling,
// so FFT(J,1) followed by FFT(J,-1) restores J up to rounding error.
void FFT(_ *J, int tp) {
	REP(i,0,lim-1) if (i<R[i]) swap(J[i],J[R[i]]);
	for (int j=1; j<lim; j<<=1) {
		// Root of unity of order 2j: e^{tp * i*pi/j}.
		_ T = {cos(pi/j), tp*sin(pi/j)};
		for (int k=0; k<lim; k+=(j<<1)) {
			_ t = {1, 0};
			for (int l=0; l<j; ++l, t=t*T) {
				_ y = t*J[k+j+l];
				J[k+j+l] = J[k+l]-y;
				J[k+l] = J[k+l]+y;
			}
		}
	}
	// Normalise the array that was actually transformed (J, not a global buffer),
	// every element and both components.
	if (tp==-1) {
		REP(i,0,lim-1) J[i].x/=lim, J[i].y/=lim;
	}
}
 
// c = a * b (polynomial product) via FFT. Requires lim > deg(a) + deg(b) and a, b
// zero-padded up to lim. Inputs are copied into A/B first, so c may alias a or b.
// Relies on FFT(.., -1) applying the 1/lim scaling.
void mul(ll *a, ll *b, ll *c) {
	// Reset the imaginary parts as well: a previous transform leaves nonzero .y in A/B,
	// which would corrupt every call after the first.
	REP(i,0,lim-1) {
		A[i].x=a[i], A[i].y=0;
		B[i].x=b[i], B[i].y=0;
	}
	FFT(A,1),FFT(B,1);
	REP(i,0,lim-1) A[i]=A[i]*B[i];
	FFT(A,-1);
	// Round to nearest; unlike "+0.5 then truncate", this is also correct for negative results.
	REP(i,0,lim-1) c[i]=llround(A[i].x);
}

 

NTT板子 (中间结果未做非负化处理, 取值在$(-P,P)$内, 最后答案若为负需要加上$P$)

// Array capacity and NTT-friendly prime P = 119 * 2^23 + 1 with primitive root G = 3;
// Gi is the inverse of G modulo P.
const int N = 1e6+10;
const int P = 998244353;
const int G = 3;
const int Gi = 332748118;
// lim = 2^l is the transform length, A/B are mul's work buffers, R the bit-reversal table.
int lim, l;
int A[N], B[N], R[N];
// Pick the transform length lim = 2^l, the smallest power of two with lim > n,
// and build the bit-reversal permutation R of [0, lim) used by NTT.
void init(int n) {
	for (lim=1,l=0; lim<=n; lim<<=1,++l) ;
	// For n == 0 we get l == 0, and (i&1)<<(l-1) would shift by -1 (undefined behaviour).
	REP(i,0,lim-1) R[i]=(R[i>>1]>>1)|(l ? (i&1)<<(l-1) : 0);
}

void NTT(int *J, int tp=1) {
	REP(i,0,lim-1) if (i<R[i]) swap(J[i],J[R[i]]);
	for (int j=1; j<lim; j<<=1) {
		ll T = qpow(tp==1?G:Gi,(P-1)/(j<<1));
		for (int k=0; k<lim; k+=j<<1) {
			ll t = 1;
			for (int l=0; l<j; ++l,t=t*T%P) {
				int y = t*J[k+j+l]%P;
				J[k+j+l] = (J[k+l]-y)%P;
				J[k+l] = (J[k+l]+y)%P;
			}
		}
	}
	if (tp==-1) {
		ll inv = qpow(lim, P-2);
		REP(i,0,lim-1) J[i]=(ll)inv*J[i]%P;
	}
}

void mul(int *a, int *b, int *c) {
	REP(i,0,lim-1) A[i]=a[i],B[i]=b[i];
	NTT(A),NTT(B);
	REP(i,0,lim-1) c[i]=(ll)A[i]*B[i]%P;
	NTT(c,-1);
}

 

练习1. 牛客201 I Steins;Gate

大意: 给定$n$元素序列$a$和质数模数$P$, 对于每个$1\le k\le n$, 求满足$a_ia_j \bmod P = a_k$的有序二元组$(i,j)$个数.

对$a_i\not\equiv 0 \pmod P$, 令$a_i\equiv g^{A_i} \pmod P$ ($0\le A_i\le P-2$, 即离散对数), 其中$g$为模$P$的原根.

可以得到(当$a_i,a_j,a_k$均不为$0$时) $A_i+A_j \equiv A_k \pmod{P-1}$

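所以统计每个指数$A$出现的次数, 做一次自卷积, 再把下标$\ge P-1$的项折回(下标减去$P-1$)即可. 含$0$的情况单独计算: 若共有$c$个$a_i\equiv 0$, 则乘积为$0$的有序对共有$n^2-(n-c)^2$个; 若$a_k\ge P$则答案为$0$.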

#include <iostream>
#include <string.h>
#include <math.h>
#define REP(i,a,n) for(int i=a;i<=n;++i)
using namespace std;
typedef long long ll;
// Modular exponentiation by repeated squaring: a^n mod m (1 % m makes m == 1 yield 0).
ll qpow(ll a,ll n,ll m) {
    ll res = 1 % m;
    a %= m;
    while (n) {
        if (n & 1) res = res * a % m;
        a = a * a % m;
        n >>= 1;
    }
    return res;
}
int rt(int m) {
    REP(i,2,m) {
        int x=m-1, ok=1, mx = sqrt(m-0.5);
        REP(k,2,mx) if (x%k==0) {
            if (qpow(i,(m-1)/k,m)==1) {ok=0;break;}
            while (x%k==0) x/=k;
        }
        if (ok&&(x==1||qpow(i,(m-1)/x,m)>1)) return i;
    }
    throw;
}


const double pi = acos(-1);
const int N = 1e6+10;
// Minimal complex number: x is the real part, y the imaginary part.
// A and B are the global work buffers used by mul.
struct _ {
	double x, y;
} A[N], B[N];
// Complex product (a.x + i a.y)(b.x + i b.y).
_ operator * (_ a, _ b) {
	_ r;
	r.x = a.x*b.x - a.y*b.y;
	r.y = a.x*b.y + a.y*b.x;
	return r;
}
// Component-wise complex sum.
_ operator + (_ a, _ b) {
	_ r;
	r.x = a.x + b.x;
	r.y = a.y + b.y;
	return r;
}
// Component-wise complex difference.
_ operator - (_ a, _ b) {
	_ r;
	r.x = a.x - b.x;
	r.y = a.y - b.y;
	return r;
}

// n = sequence length, m = prime modulus; x and s are unused in this program.
// lim = 2^l is the FFT length chosen by init, R its bit-reversal table.
int n, m, x, s;
int lim, l;
// a = input sequence (1-indexed), g[e] = root^e mod m, no[] = discrete-log table.
int a[N], g[N], no[N], R[N];
// f[e] = number of inputs with discrete log e; mul convolves it in place.
ll f[N];

void init(int n) {
    for (lim=1,l=0; lim<=n; lim<<=1,++l) ;
    REP(i,0,lim-1) R[i]=(R[i>>1]>>1)|((i&1)<<(l-1));
}

// In-place iterative FFT of length lim (bit-reversal table R from init).
// tp = 1: forward; tp = -1: inverse WITHOUT the 1/lim normalisation
// (mul divides by lim when it reads the result back).
void FFT(_ *J, int tp) {
	for (int i = 0; i < lim; ++i)
		if (i < R[i]) swap(J[i], J[R[i]]);
	for (int half = 1; half < lim; half <<= 1) {
		const _ wn = {cos(pi / half), tp * sin(pi / half)};
		for (int st = 0; st < lim; st += half << 1) {
			_ w = {1, 0};
			for (int q = 0; q < half; ++q) {
				_ *lo = J + st + q, *hi = lo + half;
				_ y = w * (*hi);
				*hi = *lo - y;
				*lo = *lo + y;
				w = w * wn;
			}
		}
	}
}
 
// c = a * b via FFT, rounded to integers, then folded cyclically so that exponents are
// taken modulo m-1: c[i] += c[i+m-1] for 0 <= i <= m-2. Only the real parts of A/B are
// loaded (fine for the single call made in main). c may alias a and b.
void mul(ll *a, ll *b, ll *c) {
	for (int i = 0; i < lim; ++i) {
		A[i].x = a[i];
		B[i].x = b[i];
	}
	FFT(A, 1);
	FFT(B, 1);
	for (int i = 0; i < lim; ++i) A[i] = A[i] * B[i];
	FFT(A, -1);
	for (int i = 0; i < lim; ++i) c[i] = A[i].x / lim + 0.5;
	for (int i = 0; i + 1 < m; ++i) c[i] += c[i + m - 1];
}
 
int main() {
	scanf("%d%d", &n, &m);
	g[0] = 1, g[1] = rt(m);
	REP(i,1,m-2) g[i] = (ll)g[i-1]*g[1]%m;
	REP(i,0,m-2) no[g[i]]=i;
	int cnt = 0;
	REP(i,1,n) {
		scanf("%d",a+i);
		int t = a[i]%m;
		if (t) ++f[no[t]];
		else ++cnt;
	}
	init(2*m);
	mul(f,f,f);
	REP(i,1,n) {
		ll ans = 0;
		if (a[i]<m) {
			if (a[i]) ans = f[no[a[i]]];
			else ans = (ll)cnt*n+(ll)(n-cnt)*cnt;
		}
		printf("%lld\n", ans);
	}
}

 

练习2. bzoj 3992: [SDOI2015]序列统计

设使用$i$个数, 前缀积模$m$为$x$的方案数为$f[i][x]$

有转移$$f[i][x]=\sum\limits_{ab=x}f[i-1][a]f[1][b]$$

用原根表示也就是
$$f[i][g^X]=\sum\limits_{A+B\equiv X(mod\space m-1)}f[i-1][g^A]f[1][g^B]$$

这明显是个卷积形式, 也就是说构造多项式
$$A_i(x)=\sum\limits_{0\le k\le m-2} f[i][g^k]x^k$$

那么有

$$A_{i}(x)=A_{i-1}(x)\cdot A_1(x) \pmod{x^{m-1}-1}$$
(指数在模$m-1$意义下循环, 即$x^{m-1}=1$, 这是循环卷积, 对应代码中把下标$\ge m-1$的项折回)

也就是说$A_n(x)=A_1(x)^n \pmod{x^{m-1}-1}$, 用快速幂优化即可做到$O(m\log m\log n)$

#include <iostream>
#include <string.h>
#include <math.h>
#define REP(i,a,n) for(int i=a;i<=n;++i)
using namespace std;
typedef long long ll;
// NTT-friendly prime P = 479 * 2^21 + 1 with primitive root G = 3; Gi is G^{-1} mod P.
const int P = 1004535809;
const int G = 3;
const int Gi = 334845270;
// Modular exponentiation by repeated squaring: a^n mod m, with m defaulting to the NTT
// modulus P (1 % m makes m == 1 yield 0).
ll qpow(ll a,ll n,ll m=P) {
    ll res = 1 % m;
    a %= m;
    while (n) {
        if (n & 1) res = res * a % m;
        a = a * a % m;
        n >>= 1;
    }
    return res;
}
int rt(int m) {
    REP(i,2,m) {
        int x=m-1, ok=1, mx = sqrt(m-0.5);
        REP(k,2,mx) if (x%k==0) {
            if (qpow(i,(m-1)/k,m)==1) {ok=0;break;}
            while (x%k==0) x/=k;
        }
        if (ok&&(x==1||qpow(i,(m-1)/x,m)>1)) return i;
    }
    throw;
}

const int N = 1e6+10;
// n = sequence length, m = prime modulus, x = target product, s = |S| (as read in main);
// lim = 2^l is the NTT length chosen by init.
int n, m, x, s;
int lim, l;
// a = result polynomial, g/no = power and discrete-log tables, f = indicator of S in
// log space, A/B = mul's NTT buffers, R = bit-reversal table.
int a[N], g[N], no[N], f[N], A[N], B[N], R[N];


void init(int n) {
    for (lim=1,l=0; lim<=n; lim<<=1,++l) ;
    REP(i,0,lim-1) R[i]=(R[i>>1]>>1)|((i&1)<<(l-1));
}

void NTT(int *J, int tp=1) {
    REP(i,0,lim-1) if (i<R[i]) swap(J[i],J[R[i]]);
    for (int j=1; j<lim; j<<=1) {
        ll T = qpow(tp==1?G:Gi,(P-1)/(j<<1));
        for (int k=0; k<lim; k+=j<<1) {
            ll t = 1;
            for (int l=0; l<j; ++l,t=t*T%P) {
                int y = t*J[k+j+l]%P;
                J[k+j+l] = (J[k+l]-y)%P;
                J[k+l] = (J[k+l]+y)%P;
            }
        }
    }
    if (tp==-1) {
        ll inv = qpow(lim, P-2);
        REP(i,0,lim-1) J[i]=(ll)inv*J[i]%P;
    }
}

void mul(int *a, int *b, int *c) {
    REP(i,0,lim-1) A[i]=a[i],B[i]=b[i];
    NTT(A),NTT(B);
    REP(i,0,lim-1) c[i]=(ll)A[i]*B[i]%P;
    NTT(c,-1);
    REP(i,0,m-2) c[i]=(c[i]+c[i+m-1])%P,c[i+m-1]=0;
}

// Store f^n (power under the cyclic product mul) in a, using recursive binary
// exponentiation. For n <= 0, a becomes the multiplicative identity: 1 at exponent 0,
// i.e. the residue g^0 = 1. Previously solve(0) recursed forever, because n/2 == 0
// never reaches the n == 1 base case.
void solve(int n) {
    if (n <= 0) {
        memset(a, 0, sizeof a);
        a[0] = 1;
    } else if (n == 1) memcpy(a,f,sizeof f);
    else {
        solve(n/2);
        mul(a,a,a);
        if (n&1) mul(a,f,a);
    }
}

int main() {
    scanf("%d%d%d%d", &n, &m, &x, &s);
    init(m*2);
    g[0] = 1, g[1] = rt(m);
    REP(i,2,m-2) g[i]=(ll)g[i-1]*g[1]%m;
    REP(i,0,m-2) no[g[i]]=i;
    REP(i,1,s) {
        int t;
        scanf("%d", &t);
        if (t) f[no[t]]=1;
    }
    solve(n);
    int ans = a[no[x]];
    if (ans<0) ans += P;
    printf("%d\n", ans);
}

 

posted @ 2019-06-15 18:24  uid001  阅读(183)  评论(0编辑  收藏  举报