洛谷 P3321 [SDOI2015] 序列统计
感觉挺综合的一道题。
考虑朴素 dp,\(\forall x \in S, f_{i + 1, jx \bmod m} \gets f_{i,j}\)。复杂度 \(O(nm^2)\)。显然可以矩乘优化至 \(O(m^3 \log n)\),但是不能通过。
如果转移式中是加法而不是乘法,那很容易卷积优化。接下来是 一个很重要的套路:化乘为加。 实数范围内可以取对数,正整数范围内,考虑取 \(m\) 的原根 \(g\),因为 \(g\) 满足 \(g^0, g^1, ..., g^{m-2}\) 两两不同,所以可以把 \(1 \sim m - 1\) 的数映射到指数。
接下来求这个多项式的 \(n\) 次幂即可。注意每次倍增时要把后面的部分加到前面去,因为是在模 \(m\) 意义下。
code
// Problem: P3321 [SDOI2015]序列统计
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P3321
// Memory Limit: 125 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 32100;
const ll mod = 1004535809, G = 3;
inline ll qpow(ll b, ll p, const ll &mod) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
ll n, m, K, X, p, a[maxn], r[maxn], b[maxn], tot, c[maxn];
typedef vector<ll> poly;
inline poly NTT(poly a, int op) {
int n = (int)a.size();
for (int i = 0; i < n; ++i) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int k = 1; k < n; k <<= 1) {
ll wn = qpow(op == 1 ? G : qpow(G, mod - 2, mod), (mod - 1) / (k << 1), mod);
for (int i = 0; i < n; i += (k << 1)) {
ll w = 1;
for (int j = 0; j < k; ++j, w = w * wn % mod) {
ll x = a[i + j], y = w * a[i + j + k] % mod;
a[i + j] = (x + y) % mod;
a[i + j + k] = (x - y + mod) % mod;
}
}
}
return a;
}
inline poly operator * (poly a, poly b) {
a = NTT(a, 1);
b = NTT(b, 1);
int n = (int)a.size();
for (int i = 0; i < n; ++i) {
a[i] = a[i] * b[i] % mod;
}
a = NTT(a, -1);
ll inv = qpow(n, mod - 2, mod);
for (int i = 0; i < n; ++i) {
a[i] = a[i] * inv % mod;
}
return a;
}
inline bool check(ll x) {
if (qpow(x, p, m) != 1) {
return 0;
}
for (int i = 1; i <= tot; ++i) {
if (qpow(x, p / b[i], m) == 1) {
return 0;
}
}
return 1;
}
inline poly qpow(poly a, ll m, ll p) {
int n = (int)a.size();
poly res(n);
res[0] = 1;
while (p) {
if (p & 1) {
res = res * a;
for (int i = m + 1; i < n; ++i) {
// 对 m + 1 取模
res[i % (m + 1)] = (res[i % (m + 1)] + res[i]) % mod;
res[i] = 0;
}
}
a = a * a;
for (int i = m + 1; i < n; ++i) {
a[i % (m + 1)] = (a[i % (m + 1)] + a[i]) % mod;
a[i] = 0;
}
p >>= 1;
}
return res;
}
void solve() {
scanf("%lld%lld%lld%lld", &n, &m, &X, &K);
p = m - 1;
ll x = p;
for (ll i = 2; i * i <= x; ++i) {
if (x % i == 0) {
b[++tot] = i;
while (x % i == 0) {
x /= i;
}
}
}
if (x > 1) {
b[++tot] = x;
}
ll g = -1;
for (int i = 1; i < m; ++i) {
if (check(i)) {
g = i;
break;
}
}
for (ll i = 0, x = 1; i <= m - 2; ++i, x = x * g % m) {
c[x] = i;
}
int k = 0;
while ((1 << k) <= m * 2) {
++k;
}
for (int i = 1; i < (1 << k); ++i) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
}
while (K--) {
ll x;
scanf("%lld", &x);
if (x % m) {
a[c[x]] = 1;
}
}
poly A;
for (int i = 0; i < (1 << k); ++i) {
A.pb(a[i]);
}
poly B = qpow(A, m - 2, n);
printf("%lld\n", B[c[X]]);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}