P3321 [SDOI2015]序列统计 (NTT快速幂)

Code:

#include <map>
#include <set>
#include <array>
#include <queue>
#include <stack>
#include <cmath>
#include <vector>
#include <cstdio>
#include <cstring>
#include <sstream>
#include <iostream>
#include <stdlib.h>
#include <algorithm>
#include <unordered_map>

using namespace std;

// Short type aliases used throughout the solution.
using ll = long long;
using PII = pair<int, int>;

// scanf shorthands; the visible code reads with cin and mentions sd() only in a comment.
#define sd(a) scanf("%d", &a)
#define sdd(a, b) scanf("%d%d", &a, &b)
#define slld(a) scanf("%lld", &a)
#define slldd(a, b) scanf("%lld%lld", &a, &b)

const int N = 30010;            // capacity of the global arrays (3e4 + 10)
const int M = 1000020;          // 1e6 + 20; not referenced by the visible code
const int mod = 1004535809;     // NTT-friendly prime 479 * 2^21 + 1; ntt() uses 3 as its generator
const int INF = 0x3f3f3f3f;
const double PI = acos(-1.0);

int n;          // sequence length; read in solve() and passed to ntt_qmi as the exponent
int m;          // modulus of the product constraint; read in solve()
int rev[N];     // bit-reversal permutation table, filled by change()

// Modular exponentiation by repeated squaring: returns a^b mod p, always in [0, p).
//   a - base; reduced modulo p first, so large or negative bases are handled
//   b - exponent; must be >= 0 (a negative exponent terminates and returns 1 % p
//       instead of looping forever on the arithmetic right shift)
//   p - modulus; must be >= 1 and small enough that (p - 1)^2 fits in long long
long long qmi(long long a, long long b, long long p){
    long long res = 1 % p;      // 1 % p so that p == 1 correctly yields 0
    a %= p;
    if(a < 0) a += p;           // normalise a negative base into [0, p)
    while(b > 0){
        if(b & 1) res = res * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return res;
}

// Bit-reversal reordering for the iterative NTT. Builds the global rev[] table for a
// transform of size len (assumed to be a power of two, as solve() computes it) and
// permutes y[0..len) so that y[idx] and y[rev[idx]] trade places.
void change(ll y[], int len){
    const int top = len >> 1;
    for(int idx = 0; idx < len; idx ++){
        rev[idx] = (rev[idx >> 1] >> 1) | ((idx & 1) ? top : 0);
    }

    for(int idx = 0; idx < len; idx ++){
        const int partner = rev[idx];
        if(idx < partner) swap(y[idx], y[partner]);    // each pair is swapped only once
    }
}

// In-place number-theoretic transform of y[0..len) modulo `mod`, with 3 as the generator.
//   len - transform size (a power of two)
//   on  - 1 for the forward transform, -1 for the inverse (which also scales by 1/len)
void ntt(ll y[], int len, int on){
    change(y, len);

    const bool inverse = (on == -1);
    for(int span = 2; span <= len; span <<= 1){
        const int half = span >> 1;
        // Principal span-th root of unity; inverted for the inverse transform.
        ll root = qmi(3, (mod - 1) / span, mod);
        if(inverse) root = qmi(root, mod - 2, mod);

        for(int base = 0; base < len; base += span){
            ll w = 1;
            for(int off = 0; off < half; off ++){
                ll &lo = y[base + off];
                ll &hi = y[base + off + half];
                const ll even = lo;
                const ll odd = w * hi % mod;
                lo = (even + odd) % mod;
                hi = (even - odd + mod) % mod;
                w = w * root % mod;
            }
        }
    }

    if(!inverse) return;

    const ll len_inv = qmi(len, mod - 2, mod);
    for(int i = 0; i < len; i ++) y[i] = y[i] * len_inv % mod;
}

ll mid[N];      // not referenced by the visible code
ll a[N];        // polynomial raised to the n-th power by ntt_qmi (exponent histogram)
ll res[N];      // result polynomial accumulated by ntt_qmi; solve() seeds res[0] = 1

void ntt_qmi(ll b, int len){
    int mm = m - 1;
    while(b){
        ntt(a, len, 1);
        if(b & 1){
            ntt(res, len, 1);
            for(int i = 0; i < len; i ++){
                res[i] = res[i] * a[i] % mod;
            }
            ntt(res, len, -1);

            for(int i = mm; i < len; i ++){
                res[i % mm] = (res[i % mm] + res[i]) % mod;
                res[i] = 0;
            }
        }
        for(int i = 0; i < len; i ++){
            a[i] = a[i] * a[i] % mod;
        }
        ntt(a, len, -1);
        for(int i = mm; i < len; i ++){
            a[i % mm] = (a[i % mm] + a[i]) % mod;
            a[i] = 0;
        }
        b >>= 1;
    }
}

bool st[N];     // st[v] is set to true when get() marks v as composite
int primes[N];  // primes found by get(), in increasing order
int cnt = 0;    // number of entries currently stored in primes[]

// Linear (Euler) sieve over [2, n]: appends each prime to primes[] (advancing cnt) and
// marks composites in st[]. Note that cnt is not reset, so calling it twice appends again.
void get(int n){
    for(int v = 2; v <= n; v ++){
        if(!st[v]) primes[cnt ++] = v;

        for(int j = 0; primes[j] <= n / v; j ++){
            const int p = primes[j];
            st[v * p] = true;
            // Once p divides v it is v's smallest prime factor; stopping here marks
            // every composite exactly once.
            if(v % p == 0) break;
        }
    }
}

int get_a(){
    get(N - 10);

    ll phi = m;
    ll mm = m;
    for(int i = 2; i <= mm / i; i ++){
        if(mm % i == 0){
            phi = phi * (i - 1) / i;
            while(mm % i == 0){
                mm /= i;
            }
        }
    }
    if(mm > 1) phi = phi * (mm - 1) / mm;

    vector <int> d;
    for(int i = 2; i <= phi / i; i ++){
        if(phi % i == 0){
            d.push_back(i);
            if(i != phi / i) d.push_back(phi / i);
        }
    }

    for(int i = 0; i < cnt; i ++){
        int p = primes[i];
        bool flag = true;
        for(int j = 0; j < d.size(); j ++){
            if(qmi(p, phi / d[j], m) == 1){
                flag = false;
                break;
            }
        }
        if(flag == true) return p;
    }
}

int x;          // target residue of the product, read in solve()
int num;        // number of elements in the input set
int _x;         // scratch variable holding each input element as it is read
ll f_log[N];    // discrete-log table: f_log[residue] = exponent w.r.t. the primitive root
ll vis[N];      // vis[k] = number of input elements whose discrete log is k

void solve(){
    
    cin >> n >> m >> x >> num;

    int _a = get_a();
    for(int i = 0, j = 1; i < m - 1; i ++, j = (ll)j * _a % m){
        f_log[j] = i;
    }

    for(int i = 0; i < m; i ++) vis[i] = 0;

    for(int i = 0; i < num; i ++){
        cin >> _x;
        if(_x != 0) vis[f_log[_x]] ++;
    }

    res[0] = 1;
    for(int i = 0; i < m; i ++){
        a[i] = vis[i];
    }

    int len = 1;
    while(len < m + m - 1) len <<= 1;

    ntt_qmi(n, len);

    cout << res[f_log[x]] << "\n";

}

// Entry point: sets up fast iostreams and runs solve() once (multi-test input is
// supported by un-commenting the read of T).
int main()
{
#ifndef ONLINE_JUDGE
    // Local debugging only: redirect stdin to the author's input file, but only when
    // that file can actually be opened. A failed freopen can leave stdin closed, and
    // every later read would then fail silently on any other machine.
    if(FILE *probe = fopen("/home/jungu/code/in.txt", "r")){
        fclose(probe);
        freopen("/home/jungu/code/in.txt", "r", stdin);
    }
    // freopen("/home/jungu/桌面/11.21/2/in9.txt", "r", stdin);
#endif
    ios::sync_with_stdio(false);
    cin.tie(nullptr);   // cout is not tied to any stream by default, so only cin needs untying

    int T = 1;
    // sd(T);
    // cin >> T;
    while (T--)
    {
        solve();
    }

    return 0;
}

 

posted @ 2021-02-10 22:33  君顾  阅读(59)  评论(0编辑  收藏  举报