[SDOI2015]序列统计 - 生成函数 + NTT

Description

小C有一个集合 \(S\),里面的元素都是小于 \(m\) 的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为 \(n\) 的数列,数列中的每个数都属于集合 \(S\)

小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数 \(x\),求所有可以生成出的,且满足数列中所有数的乘积 \(\bmod \ m\) 的值等于 \(x\) 的不同的数列的有多少个。

小C认为,两个数列 \(A\)\(B\) 不同,当且仅当 \(\exists i \text{ s.t. } A_i \neq B_i\)。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案对 \(1004535809\) 取模的值就可以了。

Solution

首先考虑朴素算法的\(dp\)转移:\(f_i \times f_j \rightarrow f_{i\times j}\)

观察这个转移,如果转移是\(f_i \times f_j \rightarrow f_{i+j}\),那我们不就可以卷起来了吗?

乘法变成加法,想到了什么?对数。

题目中要求乘起来为\(x\)的方案数,那我们只需要对\(S_i\)\(x\)取关于\(M\)的原根的离散对数,然后问题就转化为了\(\log {S_i}\)加起来为\(\log x\)的方案数。

小于\(M\)的非负整数,除\(0\)外,一共有\(M-2\)个,值域为\([1,M-1]\)。因为\(g^0 \equiv g^{M-1}(\mod M)\),且\(g^0\)\(g^{M-2}\)两两不同,所以\(S_i\)\(\log S_i\)是一一对应的,\(\log S_{i}\)的范围是\([0,M-2]\)

接下来,定义生成函数

\[A(x)=\sum \limits_{i=0}^{\infty}a_ix^i \]

对于集合中的每个数\(S_i\),若\(S_i\)不为\(0\),令\(a_{\log S_i}=1\)

然后快速幂求出\(A^{N}(x)\),第\(\log x\)项就是答案。

最后需要注意的一点是,因为取离散对数后要求的是加起来\(\mod M-1\)\(\log x\),所以每次乘法后要把所有次数\(\mod M-1\)\(i\)的项统计到次数为\(i\)的项上。

Code

#include <bits/stdc++.h>
using namespace std;

inline int ty() {
	char ch = getchar(); int x = 0, f = 1;
	while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
	while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
	return x * f;
}

int ksm(int a, int b, int mod) {
	int ret = 1;
	for ( ; b; b >>= 1) {
		if (b & 1) ret = 1ll * ret * a % mod;
		a = 1ll * a * a % mod;
	}
	return ret;
}

const int P = 1004535809, G = 3, Gx = ksm(G, P - 2, P);
const int _ = 3e4 + 10;
int N, M, X, S, g;
int F[_], H[_], r[_], lim = 1, xx;
map<int, int> mp;

int root(const int p) {
	for (int i = 2; i <= p; ++i) {
		int x = p - 1;
		bool flag = true;
		for (int k = 2; k * k <= p - 1; ++k) if (!(x % k)) {
			if (ksm(i, (p - 1) / k, p) == 1) {
				flag = false;
				break;
			}
			while (!(x % k)) x /= k;
		}
		if (flag && (x == 1 || ksm(i, (p - 1) / x, p) != 1)) return i;
	}
}

void NTT(int *a, int op) {
	for (int i = 0; i < lim; ++i)
		if (i < r[i]) swap(a[i], a[r[i]]);
	for (int len = 2; len <= lim; len <<= 1) {
		int mid = len >> 1;
		int Wn = ksm(op == 1 ? G : Gx, (P - 1) / len, P);
		for (int i = 0; i < lim; i += len) {
			int w = 1;
			for (int j = 0; j < mid; ++j, w = 1ll * w * Wn % P) {
				int x = a[i + j], y = 1ll * w * a[i + j + mid] % P;
				a[i + j] = (x + y) % P;
				a[i + j + mid] = (x - y + P) % P;
			}
		}
	}
	if (op == -1)
		for (int i = 0; i < lim; ++i) a[i] = 1ll * a[i] * xx % P;
}

void mul(int *A, int *B, int *C) {
	static int a[_], b[_], ret[_];
	for (int i = 0; i < lim; ++i) a[i] = A[i], b[i] = B[i];
	NTT(a, 1); NTT(b, 1);
	for (int i = 0; i < lim; ++i) a[i] = 1ll * a[i] * b[i] % P;
	NTT(a, -1);
	for (int i = 0; i < M - 1; ++i) ret[i] = (a[i] + a[i + M - 1]) % P;
	for (int i = 0; i < M - 1; ++i) C[i] = ret[i];
}

int main() {
#ifndef ONLINE_JUDGE
	freopen("seq.in", "r", stdin);
	freopen("seq.out", "w", stdout);
#endif
	N = ty(), M = ty(), X = ty(), S = ty();
	g = root(M); long long t = 1;
	for (int i = 0; i < M - 1; ++i) {
		mp[t] = i;
		t = t * g % M;
	}
	for (int i = 1; i <= S; ++i) {
		int x = ty() % M;
		if (x) F[mp[x]] = 1;
	}
	H[0] = 1;
	int k = 0; while (lim <= M + M) lim <<= 1, ++k;
	for (int i = 0; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
	xx = ksm(lim, P - 2, P);
	for ( ; N; N >>= 1) {
		if (N & 1) mul(H, F, H);
		mul(F, F, F);
	}
	printf("%d\n", H[mp[X]]);
	return 0;
}
posted @ 2019-12-27 11:39  newbielyx  阅读(157)  评论(0编辑  收藏  举报