【bzoj3992】[SDOI2015]序列统计 原根+NTT

题目描述

求长度为 $n$ 的序列,每个数都是 $|S|$ 中的某一个,所有数的乘积模 $m$ 等于 $x$ 的序列数目模1004535809的值。

输入

一行,四个整数,N、M、x、|S|,其中|S|为集合S中元素个数。
第二行,|S|个整数,表示集合S中的所有元素。
1<=N<=10^9,3<=M<=8000,M为质数
1<=x<=M-1,输入数据保证集合S中元素不重复

输出

一行,一个整数,表示你求出的种类数mod 1004535809的值。

样例输入

4 3 1 2
1 2

样例输出

8


题解

原根+NTT

如果条件是和模 $m$ 等于 $x$ ,那么很明显就是一道NTT裸题。维护S集合的生成函数在模 $x^m$ 意义下的 $n$ 次幂即可。

然而本题的条件是乘积。可以求出 $m$ 的原根,对每个数取指标,那么原数相乘就变为指标相加,使用NTT快速幂即可。

求原根的过程可以直接暴力。

注意 $|S|$ 集合中的数可能有0,0是没有指标的。由于 $x\neq 0$ ,因此出现0时无意义,直接忽略这个数即可。

时间复杂度 $O(m\log^2n)$ 

#include <cstdio>
#include <algorithm>
#define N 16410
#define mod 1004535809
using namespace std;
typedef long long ll;
int m , s[N >> 1] , v[15] , tot , ind[N >> 1];
ll a[N] , ans[N];
inline ll pow(ll x , int y , ll m)
{
	ll ans = 1;
	while(y)
	{
		if(y & 1) ans = ans * x % m;
		x = x * x % m , y >>= 1;
	}
	return ans;
}
int getroot()
{
	int i , j , t = m - 1;
	for(i = 2 ; i * i <= t ; i ++ )
	{
		if(t % i == 0)
		{
			v[++tot] = i;
			while(t % i == 0) t /= i;
		}
	}
	if(t != 1) v[++tot] = t;
	for(i = 2 ; i < m ; i ++ )
	{
		for(j = 1 ; j <= tot ; j ++ )
			if(pow(i , (m - 1) / v[j] , m) == 1)
				break;
		if(j > tot) return i;
	}
	return 0;
}
void ntt(ll *a , int n , int flag)
{
	int i , j , k;
	for(k = i = 0 ; i < n ; i ++ )
	{
		if(i > k) swap(a[i] , a[k]);
		for(j = (n >> 1) ; (k ^= j) < j ; j >>= 1);
	}
	for(k = 2 ; k <= n ; k <<= 1)
	{
		ll wn = pow(3 , (mod - 1) / k , mod);
		if(flag == -1) wn = pow(wn , mod - 2 , mod);
		for(i = 0 ; i < n ; i += k)
		{
			ll w = 1 , t;
			for(j = i ; j < i + (k >> 1) ; j ++ , w = w * wn % mod)
				t = w * a[j + (k >> 1)] % mod , a[j + (k >> 1)] = (a[j] - t + mod) % mod , a[j] = (a[j] + t) % mod;
		}
	}
	if(flag == -1)
	{
		k = pow(n , mod - 2 , mod);
		for(i = 0 ; i < n ; i ++ ) a[i] = a[i] * k % mod;
		for(i = m - 1 ; i < n ; i ++ ) a[i % (m - 1)] = (a[i % (m - 1)] + a[i]) % mod , a[i] = 0;
	}
}
void Pow(int y , int n)
{
	int i;
	ans[0] = 1;
	while(y)
	{
		ntt(a , n , 1);
		if(y & 1)
		{
			ntt(ans , n , 1);
			for(i = 0 ; i < n ; i ++ ) ans[i] = ans[i] * a[i] % mod;
			ntt(ans , n , -1);
		}
		for(i = 0 ; i < n ; i ++ ) a[i] = a[i] * a[i] % mod;
		ntt(a , n , -1);
		y >>= 1;
	}
}
int main()
{
	int n , x , k , i , r , t , len = 1;
	scanf("%d%d%d%d" , &n , &m , &x , &k);
	for(i = 1 ; i <= k ; i ++ ) scanf("%d" , &s[i]);
	r = getroot();
	for(t = 1 , i = 0 ; i < m - 1 ; i ++ , t = t * r % m) ind[t] = i;
	for(i = 1 ; i <= k ; i ++ )
		if(s[i])
			a[ind[s[i]]] ++ ;
	while(len <= 2 * (m - 2)) len <<= 1;
	Pow(n , len);
	printf("%lld\n" , ans[ind[x]]);
	return 0;
}

 

 

posted @ 2018-03-21 10:10  GXZlegend  阅读(515)  评论(0编辑  收藏  举报