[BZOJ 3992] [SDOI 2015] 序列统计
Description
Solution
〖一〗
设 \(f[i][j]\) 表示前 \(i\) 个数的乘积在模 \(p\) 意义下等于 \(j\) 的方案数,有
\[f[i][j]=\sum_{k=0}^{p-1}f[i-1][k]\cdot h[j\cdot k^{-1}]
\]
其中 \(h[i]\) 表示 \(S\) 中模 \(p\) 等于 \(i\) 的元素个数。
〖二〗
设 \(g\) 为模数 \(p\) 的原根,根据原根的性质可知 \(g^1\cdots g^{p-1}\) 互不相同,设 \(f[i][j]\) 表示前 \(i\) 个数的乘积在模 \(p\) 意义下等于 \(g^j\) 的方案数,有
\[f[i][j]=\sum_{k=0}^{p-1}f[i-1][k]\cdot h[j-k]
\]
其中 \(h[i]\) 表示 \(S\) 中模 \(p\) 等于 \(g^i\) 的元素个数。
于是可以化成多项式的形式:
\[(h_0+h_1x+h_2x^2+\cdots+h_{p-1}x^{p-1})^{n-1}
\]
Code
#include <cmath>
#include <cstdio>
#include <algorithm>
const int N = 16390, P = 1004535809, G = 3, Gi = 334845270;
int n, m, x, y, s, g, nn, mm, vis[8002], R[N], h[N], a[N], L, inv;
int read() {
int x = 0; char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
return x;
}
int ksm(int a, int b, int p) {
int res = 1;
for (; b; b >>= 1, a = 1LL * a * a % p)
if (b & 1) res = 1LL * res * a % p;
return res;
}
int getroot(int x) {
if (x == 2) return 1;
int m = sqrt(x - 1);
for (int i = 2; ; ++i) {
bool ok = 1;
for (int j = 2; j <= m; ++j)
if (ksm(i, (x - 1) / j, x) == 1) { ok = 0; break; }
if (ok) return i;
}
}
void NTT(int *A, int f) {
for (int i = 0; i < nn; ++i) if (i < R[i]) std::swap(A[i], A[R[i]]);
for (int i = 1; i < nn; i <<= 1) {
int wn = ksm(f == 1 ? G : Gi, (P - 1) / (i << 1), P);
for (int j = 0, r = i << 1; j < nn; j += r) {
int w = 1;
for (int k = 0; k < i; ++k, w = 1LL * w * wn % P) {
int x = A[j + k], y = 1LL * w * A[i + j + k] % P;
A[j + k] = (x + y) % P, A[i + j + k] = (x - y + P) % P;
}
}
}
}
void mul(int *a, int *b) {
int c[N] = {}, d[N] = {};
for (int i = 1; i < m; ++i) c[i] = a[i], d[i] = b[i];
NTT(c, 1), NTT(d, 1);
for (int i = 0; i < nn; ++i) a[i] = 1LL * c[i] * d[i] % P;
NTT(a, -1);
for (int i = 0; i <= mm; ++i) a[i] = 1LL * a[i] * inv % P;
for (int i = m; i <= mm; ++i) {
a[i - m + 1] += a[i];
if (a[i - m + 1] >= P) a[i - m + 1] -= P;
a[i] = 0;
}
}
void fastpow(int b) {
for (; b; b >>= 1, mul(a, a))
if (b & 1) mul(h, a);
}
int main() {
n = read(), m = read(), x = read(), s = read();
g = getroot(m);
for (int i = 1; i <= s; ++i) vis[read()] = 1;
for (int i = 1, t; i < m; ++i) {
t = ksm(g, i, m), h[i] = a[i] = vis[t];
if (x == t) y = i;
}
mm = (m - 1) << 1;
for (nn = 1; nn <= mm; nn <<= 1) ++L;
inv = ksm(nn, P - 2, P);
for (int i = 0; i < nn; ++i) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
fastpow(n - 1);
printf("%d\n", h[y]);
return 0;
}