bzoj3992: [SDOI2015]序列统计

传送门:http://www.lydsy.com:808/JudgeOnline/problem.php?id=3992

思路:M是一个质数,问题又是求乘积,于是我们就可以想到利用M的原根g把问题变成求和(我怎么想不到啊。。。)

根据原根的性质,我们可以把1到M-1中的数i表示为(g^b[i])%M,且指数互不相同

那么X就可以表示成(g^b[x])%M

问题就转化为:然后问题转化成了在序列b中,选出n个数(一个数可以取多次),且它们的和s满足s=b[x]

这个问题我们可以用母函数+NTT解决,答案就是多项式的b[x]次项的系数

因为n很大,所以再套一个快速幂即可

#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int mod=(479<<21)+1,G=3,maxn=17000;
int n,k,sum,m,inv_G,inv_N,T[maxn],vis[maxn],tim,N,rev[maxn],pos[maxn],root;

int qpow(int a,int b){
	int res=1;
	for (;b;b>>=1){
		if (b&1) res=1ll*res*a%mod;
		a=1ll*a*a%mod;
	}
	return res;
}

bool check(int x){
	int now=1;++tim;
	for (int i=1;i<n;i++,now=now*x%n){
		if (vis[now]==tim) return 0;
		vis[now]=tim;
	}
	return 1;
}

int rever(int x){int res=0,len=N;while (len--) res<<=1,res^=(x&1),x>>=1;return res;}
int findroot(){for (int i=2;i<=n;i++) if (check(i)) return i;}

struct DFT{
	int a[maxn];
	void ntt(int op){
		for (int i=0;i<N;i++) if (rev[i]>i) swap(a[rev[i]],a[i]);
		int g=op==1?G:inv_G;
		for (int sz=2;sz<=N;sz<<=1){
			int t=qpow(g,(mod-1)/sz);
			for (int bg=0;bg<N;bg+=sz)
				for (int po=bg,w=1;po<bg+(sz>>1);po++){
					int x=a[po],y=1ll*a[po+(sz>>1)]*w%mod;
					a[po]=(x+y)%mod,a[po+(sz>>1)]=(x-y+mod)%mod;
					w=1ll*w*t%mod;
				}
		}
		if (op==-1) for (int i=0;i<N;i++) a[i]=1LL*a[i]*inv_N%mod;
	}
}a,b;

void qpow(){
	b.a[0]=1;
	for (;k;k>>=1){
		a.ntt(1);
		if (k&1){
			b.ntt(1);for (int i=0;i<N;i++) b.a[i]=1ll*b.a[i]*a.a[i]%mod;
			b.ntt(-1);for (int i=N-1;i>=n-1;i--) b.a[i-n+1]=(b.a[i-n+1]+b.a[i])%mod,b.a[i]=0;
			//因为是在mod m(代码里的n)的条件下进行的,所以要把和超过m的后半部分的答案加到前半部分去,而不是简单的清空
		}
		for (int i=0;i<N;i++) a.a[i]=1ll*a.a[i]*a.a[i]%mod;a.ntt(-1);
		for (int i=N-1;i>=n-1;i--) a.a[i-n+1]=(a.a[i-n+1]+a.a[i])%mod,a.a[i]=0;
	}
}

int main(){
	inv_G=qpow(G,mod-2);
	scanf("%d%d%d%d",&k,&n,&sum,&m);
	for (int i=1;i<=m;i++) scanf("%d",&T[i]);
	N=(int)ceil(log2(n))+1;
	for (int i=0;i<(1<<N);i++) rev[i]=rever(i);
	N=1<<N,inv_N=qpow(N,mod-2),root=findroot();
	for (int i=0,res=1;i<n-1;i++) pos[res]=i,res=res*root%n;

	for (int i=1;i<=m;i++) if (T[i]) a.a[pos[T[i]]]++;
	qpow(),printf("%d\n",b.a[pos[sum]]);
	return 0;
}


posted @ 2015-07-25 09:23  orzpps  阅读(134)  评论(0编辑  收藏  举报