[SDOI 2015]序列统计

Description

题库链接

给出集合 $S$ ,元素都是小于 $M$ 的非负整数。问能够生成出多少个长度为 $N$ 的数列 $A$ ,数列中的每个数都属于集合 $S$ ,并且

$$\prod_{i=1}^N A_i\equiv x \pmod{M}$$

答案对 $1004535809$ 取模。

$1\leq N\leq 10^9,3\leq M\leq 8000, M 为质数,0\leq x\leq M-1$

Solution

显然能够得到 $DP$ 的解法:令 $f_{i,j}$ 为生成序列长度为 $i$ 时,乘积在模 $M$ 意义下为 $j$ 的方案数。

显然 $f_{i,j}\rightarrow f_{i+1,(j\times w)\mod M},w\in S$ 。

但 $n\leq 10^9$ 显然不能递推。考虑优化。

由于乘法不太好搞,我们试着换种思路,我们不妨将集合内数取 $\log$ 。那么 $f_{i,\log j}\rightarrow f_{i+1,\log j+\log w},w\in S$ 。

但实数域上确实不好做,考虑取离散对数。由费马小定理,它是以 $M-1$ 为周期的,那么只要 $\text{NTT}$ 优化,加上快速幂。对模意义外的数讨论即可。

Code

#include <bits/stdc++.h>
using namespace std;
const int yzh = 1004535809;
const int N = 8000*4;

int n, m, x, s, G, lg[N+5], a, len, L, R[N+5];
int A[N+5];

int quick_pow(int a, int b, int yzh) {
    int ans = 1;
    while (b) {
    if (b&1) ans = 1ll*a*ans%yzh;
    b >>= 1, a = 1ll*a*a%yzh;
    }
    return ans;
}
void get_G() {
    int prime[N+5], tot = 0, x = m-1;
    for (int i = 2, lim = sqrt(x)+1; i <= lim; i++)
    if (x%i == 0) {
        prime[++tot] = i;
        while (x%i == 0) x /= i;
    }
    if (x != 1) prime[++tot] = x;
    for (int i = 2; true; i++) {
    int flag = 1;
    for (int j = 1; j <= tot; j++)
        if (quick_pow(i, (m-1)/prime[j], m) == 1) {
        flag = 0; break;
        }
    if (flag == 1) {G = i; break; }
    }
    for (int i = 1, g = G; i < m; i++, g = 1ll*g*G%m) lg[g] = i;
}
void NTT(int *A, int o) {
    for (int i = 0; i < len; i++) if (i < R[i]) swap(A[i], A[R[i]]);
    for (int i = 1; i < len; i <<= 1) {
    int gn = quick_pow(3, (yzh-1)/(i<<1), yzh), x, y;
    if (o == -1) gn = quick_pow(gn, yzh-2, yzh);
    for (int j = 0; j < len; j += (i<<1)) {
        int g = 1;
        for (int k = 0; k < i; k++, g = 1ll*g*gn%yzh) {
        x = A[j+k], y = 1ll*g*A[j+k+i]%yzh;
        A[j+k] = (x+y)%yzh, A[j+k+i] = (x-y+yzh)%yzh;
        }
    }
    }
    if (o == 1) return;
    for (int i = 0, inv = quick_pow(len, yzh-2, yzh); i < len; i++)
    A[i] = 1ll*A[i]*inv%yzh;
    for (int i = m; i < len; i++) (A[i%(m-1) ? i%(m-1) : m-1] += A[i]) %= yzh, A[i] = 0;
}
void NTTpow(int *A, int b) {
    int ans[N+5] = {0}; ans[0] = 1;
    while (b) {
    NTT(A, 1);
    if (b&1) {
        NTT(ans, 1);
        for (int i = 0; i < len; i++) ans[i] = 1ll*ans[i]*A[i]%yzh;
        NTT(ans, -1);
    }
    for (int i = 0; i < len; i++) A[i] = 1ll*A[i]*A[i]%yzh;
    NTT(A, -1); b >>= 1;
    }
    for (int i = 0; i < len; i++) A[i] = ans[i];
}
void work() {
    scanf("%d%d%d%d", &n, &m, &x, &s); get_G();
    for (int i = 1; i <= s; i++) {scanf("%d", &a); ++A[lg[a]]; }
    A[0] = 0;
    for (len = 1; len <= (m<<1); len <<= 1) ++L;
    for (int i = 0; i < len; i++) R[i] = (R[i>>1]>>1)|((i&1)<<(L-1));
    NTTpow(A, n); printf("%d\n", A[lg[x]]);
}
int main() {work(); return 0; }
posted @ 2018-04-10 18:29  NaVi_Awson  阅读(196)  评论(0编辑  收藏  举报