P3321 [SDOI2015]序列统计 (Sequence Statistics) — solved with NTT快速幂 (NTT-based polynomial fast exponentiation)
Code:
#include <map>
#include <set>
#include <array>
#include <queue>
#include <stack>
#include <cmath>
#include <vector>
#include <cstdio>
#include <cstring>
#include <sstream>
#include <iostream>
#include <stdlib.h>
#include <algorithm>
#include <unordered_map>
using namespace std;

typedef long long ll;
typedef pair<int, int> PII;

#define sd(a) scanf("%d", &a)
#define sdd(a, b) scanf("%d%d", &a, &b)
#define slld(a) scanf("%lld", &a)
#define slldd(a, b) scanf("%lld%lld", &a, &b)

const int N = 3e4 + 10;         // array capacity; must exceed the NTT length chosen in solve()
const int M = 1e6 + 20;
const int mod = 1004535809;     // 479 * 2^21 + 1; ntt() uses 3 as the primitive root of this prime
const int INF = 0x3f3f3f3f;
const double PI = acos(-1.0);

int n, m;       // n: sequence length (exponent); m: modulus of the product (assumed prime)
int rev[N];     // bit-reversal permutation scratch, rebuilt by change() on every call

// Modular exponentiation: returns a^b mod p using binary exponentiation.
ll qmi(ll a, ll b, ll p){
    ll res = 1;
    while(b){
        if(b & 1) res = res * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return res;
}

// Apply the bit-reversal permutation of size len (a power of two) to y in place,
// which the iterative butterfly passes in ntt() require as their input order.
void change(ll y[], int len){
    for(int i = 0; i < len; i ++){
        rev[i] = rev[i >> 1] >> 1;
        if(i & 1) rev[i] |= (len >> 1);
    }
    for(int i = 0; i < len; i ++){
        if(i < rev[i]) swap(y[i], y[rev[i]]);
    }
}

// In-place number-theoretic transform of y[0..len) modulo `mod`.
// on == 1 : forward transform.
// on == -1: inverse transform, including the final multiplication by len^{-1}.
// len must be a power of two dividing mod - 1 (i.e. at most 2^21).
void ntt(ll y[], int len, int on){
    change(y, len);
    for(int h = 2; h <= len; h <<= 1){
        // gn: primitive h-th root of unity (its inverse for the inverse transform).
        ll gn = qmi(3, (mod - 1) / h, mod);
        if(on == -1) gn = qmi(gn, mod - 2, mod);
        for(int j = 0; j < len; j += h){
            ll g = 1;
            for(int k = j; k < j + h / 2; k ++){
                ll u = y[k];
                ll t = g * y[k + h / 2] % mod;
                y[k] = (u + t) % mod;
                y[k + h / 2] = (u - t + mod) % mod;
                g = g * gn % mod;
            }
        }
    }
    if(on == -1){
        ll inv = qmi(len, mod - 2, mod);
        for(int i = 0; i < len; i ++) y[i] = y[i] * inv % mod;
    }
}

// a  : base polynomial being repeatedly squared in ntt_qmi (indexed by discrete log).
// res: accumulated power (indexed by discrete log); solve() seeds it with the identity.
ll a[N], res[N];

// Cyclically fold y[0..len): the coefficient at exponent i is added onto exponent
// i mod (m - 1) and cleared, so y represents a polynomial modulo x^(m-1) - 1.
// Discrete logs are only defined modulo m - 1, which is why products wrap around.
void fold(ll y[], int len){
    int period = m - 1;
    for(int i = period; i < len; i ++){
        y[i % period] = (y[i % period] + y[i]) % mod;
        y[i] = 0;
    }
}

// res <- res * a^b over Z_mod[x] / (x^(m-1) - 1), by binary exponentiation.
// Every product is one NTT convolution of size len followed by fold().
// len must be at least 2(m-1) - 1 so that a product of two folded polynomials
// does not alias before folding. a is overwritten.
void ntt_qmi(ll b, int len){
    while(b){
        ntt(a, len, 1);
        if(b & 1){
            ntt(res, len, 1);
            for(int i = 0; i < len; i ++) res[i] = res[i] * a[i] % mod;
            ntt(res, len, -1);
            fold(res, len);
        }
        for(int i = 0; i < len; i ++) a[i] = a[i] * a[i] % mod;
        ntt(a, len, -1);
        fold(a, len);
        b >>= 1;
    }
}

// Return the smallest primitive root modulo the global m (assumed prime),
// i.e. the least g in [1, m-1] with g^((m-1)/q) != 1 for every prime q | m-1.
// For m == 2 this is 1. Returns -1 only if no primitive root exists, which
// cannot happen for prime m; the function always returns a value.
int get_a(){
    int phi = m - 1;
    vector<int> factors;            // distinct prime factors of phi
    int rest = phi;
    for(int p = 2; p <= rest / p; p ++){
        if(rest % p == 0){
            factors.push_back(p);
            while(rest % p == 0) rest /= p;
        }
    }
    if(rest > 1) factors.push_back(rest);

    for(int g = 1; g < m; g ++){
        bool is_root = true;
        for(int q : factors){
            if(qmi(g, phi / q, m) == 1){
                is_root = false;
                break;
            }
        }
        if(is_root) return g;
    }
    return -1;
}

int x, num, _x;         // target product residue, number of input values, input scratch
ll f_log[N], vis[N];    // f_log[v]: discrete log of v (v != 0); vis[k]: count of inputs with log k

// Read "n m x num" followed by num values in [0, m), then print the number of
// length-n sequences over those values whose product is congruent to x mod m,
// modulo `mod`. Non-zero residues are mapped to discrete logs base a primitive
// root, turning products into sums mod m-1, i.e. a cyclic convolution power.
void solve(){
    cin >> n >> m >> x >> num;

    int _a = get_a();
    for(int i = 0, j = 1; i < m - 1; i ++, j = (ll)j * _a % m){
        f_log[j] = i;
    }

    for(int i = 0; i < m; i ++) vis[i] = 0;
    ll zeros = 0;   // how many input values are 0 (they have no discrete log)
    for(int i = 0; i < num; i ++){
        cin >> _x;
        if(_x != 0) vis[f_log[_x]] ++;
        else zeros ++;
    }

    if(x == 0){
        // For prime m the product is 0 iff at least one element is 0:
        // count all sequences minus those built only from non-zero values.
        ll all = qmi(num % mod, n, mod);
        ll nonzero = qmi((num - zeros) % mod, n, mod);
        cout << (all - nonzero + mod) % mod << "\n";
        return;
    }

    res[0] = 1;   // identity element: log 0 <=> value 1
    for(int i = 0; i < m; i ++){
        a[i] = vis[i];
    }
    int len = 1;
    while(len < m + m - 1) len <<= 1;
    ntt_qmi(n, len);
    cout << res[f_log[x]] << "\n";
}

int main()
{
#ifndef ONLINE_JUDGE
    // Local-debugging convenience. Redirect only when the file actually exists:
    // a failed freopen closes stdin and would leave the program with no input.
    const char* local_input = "/home/jungu/code/in.txt";
    if(FILE* probe = fopen(local_input, "r")){
        fclose(probe);
        freopen(local_input, "r", stdin);
    }
#endif
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    solve();
    return 0;
}