BZOJ3992[SDOI2015]序列统计
题目链接
解析
头一回知道原根还可以这么考……
不难想到递推的做法\(dp[i][j]\)表示长度为\(i\),乘积为\(j\)的答案,那么\(dp[i][j \cdot a[i] \ mod \ M] += dp[i - 1][j]\),
首先我们发现\(0\)可以直接丢掉,因为包含\(0\)的序列对答案不产生任何贡献,然后题目说\(M\)是个质数,如果设\(G\)是\(M\)的原根,那么\([1, M - 1]\)和\(G^1 \ mod \ M, G^2 \ mod \ M, ..., G^{M - 1} \ mod \ M\)一一对应,那么可以用\(G\)头上的指数来代替\([1, M - 1]\)中的一个数,设\(dp[i][j]\)表示长度为\(i\),乘积为\(G^j \ mod \ M\)的答案,则\(dp[i][(j + k)] += dp[i - 1][j]\)
用\(f[j]\)表示\(dp[i][j]\),\(g[j]\)表示\(dp[i - 1][j]\),\(h[j]\)表示给定的集合中是否存在\(G^j \ mod \ M\)这个数,那么
\[f[j] = \sum_{k = 1}^{M - 1} g[k] \cdot h[j - k]
\]
这是一个卷积
发现初始只有\(f[0] = 1\),所以只需对\(h\)快速幂卷积即可
当然上面都是不考虑指数对\(M - 1\)取模的时候,考虑取模就每次卷积后把\(f[i](i > M - 1)\)加到\(f[i \ mod \ (M - 1)]\)的位置就可以了
代码
不知道为什么好像我的\(FFT\)和\(NTT\)都自带大常数……耗时比第一页最慢的都多了一倍加\(1s\)……
#include <cstring>
#include <iostream>
#include <cstdio>
#include <vector>
#define MAXN 8010
typedef long long LL;
const LL mod = 1004535809ll;
int qpower(int, int, int);
void pre_prime();
void divide(int, std::vector<int> &);
int get_g(int);
void pre_rev(int);
void NTT(int *, int, int);
void mul(int *, int *, int);
int N, M, X, SZ, n, g, pre[MAXN], f[MAXN << 2], ans[MAXN << 2], rev[MAXN << 2];
std::vector<int> prime, dvd;
inline void inc(int &x, int y) { x += y; if (x >= mod) x -= mod; }
inline void dec(int &x, int y) { x -= y; if (x < 0) x += mod; }
inline int add(int x, int y) { x += y; return x >= mod ? x - mod : x; }
inline int sub(int x, int y) { x -= y; return x < 0 ? x + mod : x; }
int main() {
pre_prime();
scanf("%d%d%d%d", &N, &M, &X, &SZ);
g = get_g(M);
for (int i = 1, j = g; i < M; ++i, j = (LL)j * g % M) pre[j] = i;
for (int i = 0; i < SZ; ++i) {
int t; scanf("%d", &t);
if(t) ++f[pre[t]];
}
//debug
//printf("%d\n", g);
//for (int i = 0; i < M; ++i) printf("%d ", pre[i]);
//puts("");
ans[0] = 1;
while ((1 << n) < (M << 1)) ++n;
pre_rev(n);
while (N) {
if (N & 1) mul(ans, f, n);
mul(f, f, n);
N >>= 1;
}
printf("%d\n", ans[pre[X]]);
return 0;
}
void pre_prime() {
static bool isn_prime[MAXN];
for (int i = 2; i < MAXN; ++i) {
if (!isn_prime[i]) prime.push_back(i);
for (int j = 0; j < prime.size(), i * prime[j] < MAXN; ++j) {
isn_prime[i * prime[j]] = 0;
if (i % prime[j] == 0) break;
}
}
}
int get_g(int x) {
divide(x - 1, dvd);
for (int i = 2; i < x; ++i)
if (qpower(i, x - 1, x) == 1) {
bool flag = 1;
for (int j = 0; j < dvd.size(); ++j)
if (qpower(i, (x - 1) / dvd[j], x) == 1) { flag = 0; break; }
if (flag) return i;
}
}
void divide(int x, std::vector<int> &res) {
for (int i = 0; i < prime.size(); ++i)
if (x % prime[i] == 0) {
res.push_back(prime[i]);
while (x % prime[i] == 0) x /= prime[i];
}
}
int qpower(int x, int y, int p) {
int res = 1;
while (y) {
if (y & 1) res = (LL)res * x % p;
x = (LL)x * x % p; y >>= 1;
}
return res;
}
void NTT(int *arr, int sz, int tp) {
for (int i = 0; i < (1 << sz); ++i)
if (rev[i] > i) std::swap(arr[i], arr[rev[i]]);
for (int len = 2, half = 1; len <= (1 << sz); len <<= 1, half <<= 1) {
int wn = qpower(3, (mod - 1) / len, mod);
if (tp == -1) wn = qpower(wn, mod - 2, mod);
for (int i = 0; i < (1 << sz); i += len)
for (int j = 0, w = 1; j < half; ++j, w = (LL)w * wn % mod) {
int x = arr[i + j], y = (LL)arr[i + j + half] * w % mod;
inc(arr[i + j], y); dec(arr[i + j + half] = x, y);
}
}
if (tp == -1) {
int inv = qpower(1 << sz, mod - 2, mod);
for (int i = 0; i < (1 << sz); ++i) arr[i] = (LL)arr[i] * inv % mod;
}
}
void mul(int *a, int *b, int sz) {
static int tmp[MAXN << 2];
for (int i = 0; i < (1 << sz); ++i) tmp[i] = b[i];
NTT(a, sz, 1); NTT(tmp, sz, 1);
for (int i = 0; i < (1 << sz); ++i) a[i] = (LL)a[i] * tmp[i] % mod;
NTT(a, sz, -1);
for (int i = 1; i < M; ++i) inc(a[i], a[i + M - 1]), a[i + M - 1] = 0;
}
void pre_rev(int sz) {
for (int i = 0; i < (1 << sz); ++i)
rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << sz - 1));
}
//Rhein_E