【bzoj3992】[SDOI2015]序列统计 原根+NTT
题目描述
求长度为 $n$ 的序列,每个数都是 $|S|$ 中的某一个,所有数的乘积模 $m$ 等于 $x$ 的序列数目模1004535809的值。
输入
一行,四个整数,N、M、x、|S|,其中|S|为集合S中元素个数。
第二行,|S|个整数,表示集合S中的所有元素。
1<=N<=10^9,3<=M<=8000,M为质数
1<=x<=M-1,输入数据保证集合S中元素不重复
输出
一行,一个整数,表示你求出的种类数mod 1004535809的值。
样例输入
4 3 1 2
1 2
样例输出
8
题解
原根+NTT
如果条件是和模 $m$ 等于 $x$ ,那么很明显就是一道NTT裸题。维护S集合的生成函数在模 $x^m$ 意义下的 $n$ 次幂即可。
然而本题的条件是乘积。可以求出 $m$ 的原根,对每个数取指标,那么原数相乘就变为指标相加,使用NTT快速幂即可。
求原根的过程可以直接暴力。
注意 $|S|$ 集合中的数可能有0,0是没有指标的。由于 $x\neq 0$ ,因此出现0时无意义,直接忽略这个数即可。
时间复杂度 $O(m\log^2n)$
#include <cstdio> #include <algorithm> #define N 16410 #define mod 1004535809 using namespace std; typedef long long ll; int m , s[N >> 1] , v[15] , tot , ind[N >> 1]; ll a[N] , ans[N]; inline ll pow(ll x , int y , ll m) { ll ans = 1; while(y) { if(y & 1) ans = ans * x % m; x = x * x % m , y >>= 1; } return ans; } int getroot() { int i , j , t = m - 1; for(i = 2 ; i * i <= t ; i ++ ) { if(t % i == 0) { v[++tot] = i; while(t % i == 0) t /= i; } } if(t != 1) v[++tot] = t; for(i = 2 ; i < m ; i ++ ) { for(j = 1 ; j <= tot ; j ++ ) if(pow(i , (m - 1) / v[j] , m) == 1) break; if(j > tot) return i; } return 0; } void ntt(ll *a , int n , int flag) { int i , j , k; for(k = i = 0 ; i < n ; i ++ ) { if(i > k) swap(a[i] , a[k]); for(j = (n >> 1) ; (k ^= j) < j ; j >>= 1); } for(k = 2 ; k <= n ; k <<= 1) { ll wn = pow(3 , (mod - 1) / k , mod); if(flag == -1) wn = pow(wn , mod - 2 , mod); for(i = 0 ; i < n ; i += k) { ll w = 1 , t; for(j = i ; j < i + (k >> 1) ; j ++ , w = w * wn % mod) t = w * a[j + (k >> 1)] % mod , a[j + (k >> 1)] = (a[j] - t + mod) % mod , a[j] = (a[j] + t) % mod; } } if(flag == -1) { k = pow(n , mod - 2 , mod); for(i = 0 ; i < n ; i ++ ) a[i] = a[i] * k % mod; for(i = m - 1 ; i < n ; i ++ ) a[i % (m - 1)] = (a[i % (m - 1)] + a[i]) % mod , a[i] = 0; } } void Pow(int y , int n) { int i; ans[0] = 1; while(y) { ntt(a , n , 1); if(y & 1) { ntt(ans , n , 1); for(i = 0 ; i < n ; i ++ ) ans[i] = ans[i] * a[i] % mod; ntt(ans , n , -1); } for(i = 0 ; i < n ; i ++ ) a[i] = a[i] * a[i] % mod; ntt(a , n , -1); y >>= 1; } } int main() { int n , x , k , i , r , t , len = 1; scanf("%d%d%d%d" , &n , &m , &x , &k); for(i = 1 ; i <= k ; i ++ ) scanf("%d" , &s[i]); r = getroot(); for(t = 1 , i = 0 ; i < m - 1 ; i ++ , t = t * r % m) ind[t] = i; for(i = 1 ; i <= k ; i ++ ) if(s[i]) a[ind[s[i]]] ++ ; while(len <= 2 * (m - 2)) len <<= 1; Pow(n , len); printf("%lld\n" , ans[ind[x]]); return 0; }