洛谷 P3321 [SDOI2015] 序列统计

洛谷传送门

感觉挺综合的一道题。

考虑朴素 dp,\(\forall x \in S, f_{i + 1, jx \bmod m} \gets f_{i,j}\)。复杂度 \(O(nm^2)\)。显然可以矩乘优化至 \(O(m^3 \log n)\),但是不能通过。

如果转移式中是加法而不是乘法,那很容易卷积优化。接下来是 一个很重要的套路:化乘为加。 实数范围内可以取对数,正整数范围内,考虑取 \(m\) 的原根 \(g\),因为 \(g\) 满足 \(g^0, g^1, ..., g^{m-2}\) 两两不同,所以可以把 \(1 \sim m - 1\) 的数映射到指数。

接下来求这个多项式的 \(n\) 次幂即可。注意每次倍增时要把后面的部分加到前面去,因为是在模 \(m\) 意义下。

code
// Problem: P3321 [SDOI2015]序列统计
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P3321
// Memory Limit: 125 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 32100;
const ll mod = 1004535809, G = 3;

inline ll qpow(ll b, ll p, const ll &mod) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

ll n, m, K, X, p, a[maxn], r[maxn], b[maxn], tot, c[maxn];

typedef vector<ll> poly;

inline poly NTT(poly a, int op) {
	int n = (int)a.size();
	for (int i = 0; i < n; ++i) {
		if (i < r[i]) {
			swap(a[i], a[r[i]]);
		}
	}
	for (int k = 1; k < n; k <<= 1) {
		ll wn = qpow(op == 1 ? G : qpow(G, mod - 2, mod), (mod - 1) / (k << 1), mod);
		for (int i = 0; i < n; i += (k << 1)) {
			ll w = 1;
			for (int j = 0; j < k; ++j, w = w * wn % mod) {
				ll x = a[i + j], y = w * a[i + j + k] % mod;
				a[i + j] = (x + y) % mod;
				a[i + j + k] = (x - y + mod) % mod;
			}
		}
	}
	return a;
}

inline poly operator * (poly a, poly b) {
	a = NTT(a, 1);
	b = NTT(b, 1);
	int n = (int)a.size();
	for (int i = 0; i < n; ++i) {
		a[i] = a[i] * b[i] % mod;
	}
	a = NTT(a, -1);
	ll inv = qpow(n, mod - 2, mod);
	for (int i = 0; i < n; ++i) {
		a[i] = a[i] * inv % mod;
	}
	return a;
}

inline bool check(ll x) {
	if (qpow(x, p, m) != 1) {
		return 0;
	}
	for (int i = 1; i <= tot; ++i) {
		if (qpow(x, p / b[i], m) == 1) {
			return 0;
		}
	}
	return 1;
}

inline poly qpow(poly a, ll m, ll p) {
	int n = (int)a.size();
	poly res(n);
	res[0] = 1;
	while (p) {
		if (p & 1) {
			res = res * a;
			for (int i = m + 1; i < n; ++i) {
				// 对 m + 1 取模
				res[i % (m + 1)] = (res[i % (m + 1)] + res[i]) % mod;
				res[i] = 0;
			}
		}
		a = a * a;
		for (int i = m + 1; i < n; ++i) {
			a[i % (m + 1)] = (a[i % (m + 1)] + a[i]) % mod;
			a[i] = 0;
		}
		p >>= 1;
	}
	return res;
}

void solve() {
	scanf("%lld%lld%lld%lld", &n, &m, &X, &K);
	p = m - 1;
	ll x = p;
	for (ll i = 2; i * i <= x; ++i) {
		if (x % i == 0) {
			b[++tot] = i;
			while (x % i == 0) {
				x /= i;
			}
		}
	}
	if (x > 1) {
		b[++tot] = x;
	}
	ll g = -1;
	for (int i = 1; i < m; ++i) {
		if (check(i)) {
			g = i;
			break;
		}
	}
	for (ll i = 0, x = 1; i <= m - 2; ++i, x = x * g % m) {
		c[x] = i;
	}
	int k = 0;
	while ((1 << k) <= m * 2) {
		++k;
	}
	for (int i = 1; i < (1 << k); ++i) {
		r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
	}
	while (K--) {
		ll x;
		scanf("%lld", &x);
		if (x % m) {
			a[c[x]] = 1;
		}
	}
	poly A;
	for (int i = 0; i < (1 << k); ++i) {
		A.pb(a[i]);
	}
	poly B = qpow(A, m - 2, n);
	printf("%lld\n", B[c[X]]);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}

posted @ 2023-05-11 17:53  zltzlt  阅读(19)  评论(0编辑  收藏  举报