洛谷P3321「序列统计」

题目描述

\(C\) 有一个集合 \(S\),里面的元素都是小于 \(m\) 的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为 \(n\) 的数列,数列中的每个数都属于集合 \(S\)

\(C\) 用这个生成器生成了许多这样的数列。但是小 \(C\) 有一个问题需要你的帮助:给定整数 \(x\),求所有可以生成出的,且满足数列中所有数的乘积  \(\mod  m\) 的值等于 \(x\) 的不同的数列的有多少个。

小C认为,两个数列 \(A\)\(B\) 不同,当且仅当 \(\exists i \;\text{s.t.} A_i\neq B_i\)。另外,小 \(C\) 认为这个问题的答案可能很大,因此他只需要你帮助他求出答案对 \(1004535809\) 取模的值就可以了。

输入格式

一行,四个整数,\(n,m,x,\left | S \right |\),其中 \(\left | S \right |\) 为集合 \(S\) 中元素个数。

第二行,\(\left | S \right |\) 个整数,表示集合 \(S\) 中的所有元素。

输出格式

一行一个整数表示答案。

输入输出样例

输入

4 3 1 2
1 2

输出

8

说明/提示

【样例说明】

可以生成的满足要求的不同的数列有

\((1,1,1,1),\;(1,1,2,2),\;(1,2,1,2),\;(1,2,2,1),\;(2,1,1,2),\;(2,1,2,1),\;(2,2,1,1),\;(2,2,2,2)\)

【数据规模和约定】

对于 \(10\%\) 的数据,\(1\leq n\leq 1000\)

对于 \(30\%\) 的数据,\(3\leq m\leq 100\)

对于 \(60\%\) 的数据,\(3\leq m\leq 800\)

对于 \(100\%\) 的数据,\(1\leq n\leq 10^9,3\leq m\leq 8000,1\leq x <m\)

\(m\) 为质数,输入数据保证集合 \(S\) 中元素不重复。

题解

\(m\) 的值域很小,考虑用Triple一题的思路,将集合中的数放到多项式上,问题就可以转化为一个式子:

\[\sum_{ij\equiv x\mod m}a_ia_j \]

这已经很像我们多项式乘法的式子了,但是唯一的不同是,这里的 \(i\)\(j\) 是相乘的

\[\sum_{i+j\equiv x\mod m}a_ia_j \]

考虑如何将乘法转化成加法?

可以利用高中数学里的对数

对数有个很好的性质:

\[log^{ab}=log^{a}+log^{b} \]

不妨将 \(i\)\(j\) 都用一个数来取对数,得到 \(log^i\)\(log^j\)

在询问 \(x\) 地方的值的时候,相当于是询问 \(log^x\) 地方的值

现在问题转化为

\[\sum_{log^i+log^j\equiv log^x\mod log^m}a_ia_j \]

问题又来了,要选取哪个底数来对所有的值取对数呢?

利用原根

如果连原根都不知道是什么的小朋友,可以去百度百科初步了解一下,不会原根,你怎么学的 \(NTT\)

我们想把所有值取 \(log\) 要保证什么?

比如说我们取的底数是 \(g\)

我们需要 \(1\sim m-1\)\(log_g^i\mod log_g^m\) 互不相同

\(1\sim m-1\)\(g^i\mod m\) 互不相同

这不就是原根的第二个性质嘛

\(1\sim m-1\)\(g^i\) 正好一一对应了 \(1\sim m - 1\) 的所有值

我们就可以把原题中集合 \(S\) 的每个值去取对数了

然后我们就可以得到一个 \(1\sim m-1\) 的多项式,由于没有常数项难以处理,直接变成 \(0\sim m-2\) 的多项式来处理

对于取模,就很好处理了

两个长度为 \(m-2\) 的多项式相乘,对于得到的多项式的 \(m-1\) 次项以后,同时也对答案造成了贡献,次数模上模数加上贡献即可

我们要选 \(n\) 个数,而且没有像Triple一样“不能选重复的限制”,不用什么乱七八糟的容斥,所以直接 \(f(x)^k\) 即可

我们可以不用像多项式快速幂那么麻烦的快速幂,而且 \(a_0\) 不一定为 \(1\),也用不了

可以像普通实数快速幂一样,每次只留前 \(m-2\) 位,只不过效率是 \(nlog^{2n}\)

代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

typedef long long ll;
typedef unsigned long long ull;

using namespace std;

const int maxn = 2e5 + 50, INF = 0x3f3f3f3f, mod = 1004535809, inv3 = 334845270;

inline int read () {
	register int x = 0, w = 1;
	register char ch = getchar ();
	for (; ch < '0' || ch > '9'; ch = getchar ()) if (ch == '-') w = -1;
	for (; ch >= '0' && ch <= '9'; ch = getchar ()) x = x * 10 + ch - '0';
	return x * w;
}

inline void write (register int x) {
	if (x / 10) write (x / 10);
	putchar (x % 10 + '0');
}

int n, m, g, X, s, len = 1, bit;
bool vis[maxn];
int res[maxn], tmp[maxn], ans[maxn];
int f[maxn], id[maxn], rev[maxn];

inline int gqpow (register int a, register int b, register int ans = 1) {
	for (; b; b >>= 1, a = 1ll * a * a % m) 
		if (b & 1) ans = 1ll * ans * a % m;
	return ans;
}

inline int Get_g (register int m) {
	for (register int i = 0; i < m; i ++) {
		memset (vis, 0, sizeof 4 * m);
		for (register int k = 1, tmp; k <= m - 1; k ++) {
			tmp = gqpow (i, k);
			if (vis[tmp]) goto end;
			else vis[tmp] = 1;
		}
		return i;
		end:;
	}
	return -1;
}

inline int qpow (register int a, register int b, register int ans = 1) {
	for (; b; b >>= 1, a = 1ll * a * a % mod) 
		if (b & 1) ans = 1ll * ans * a % mod;
	return ans;
}

inline void NTT (register int len, register int * a, register int opt) {
	for (register int i = 1; i < len; i ++) if (i < rev[i]) swap (a[i], a[rev[i]]);
	for (register int d = 1; d < len; d <<= 1) {
		register int w1 = qpow (opt, (mod - 1) / (d << 1));
		for (register int i = 0; i < len; i += d << 1) {
			register int w = 1;
			for (register int j = 0; j < d; j ++, w = 1ll * w * w1 % mod) {
				register int x = a[i + j], y = 1ll * w * a[i + j + d] % mod;
				a[i + j] = (x + y) % mod, a[i + j + d] = (x - y + mod) % mod;
			}
		}
	}
}

inline void Calc (register int * a, register int * b) {
	memset (res, 0, 4 * len), memset (tmp, 0, 4 * len);
	for (register int i = 0; i < m; i ++) res[i] = a[i], tmp[i] = b[i], a[i] = 0;
	NTT (len, res, 3), NTT (len, tmp, 3);
	for (register int i = 0; i < len; i ++) res[i] = 1ll * res[i] * tmp[i] % mod;
	NTT (len, res, inv3);
	register int inv = qpow (len, mod - 2);
	for (register int i = 0; i < len; i ++) res[i] = 1ll * res[i] * inv % mod, a[i % (m - 1)] = (a[i % (m - 1)] + res[i]) % mod;
}

int main () {
	n = read(), m = read(), X = read(), s = read(), g = Get_g (m);
	for (register int i = 0; i <= m - 2; i ++) id[gqpow (g, i)] = i;
	while (len < m << 1) len <<= 1, bit ++;
	for (register int i = 1; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit - 1);
	for (register int i = 1; i <= s; i ++) {
		register int x = read();
		if (x) f[id[x]] ++;
	}
	ans[0] = 1;
	while (n) {
		if (n & 1) Calc (ans, f);
		n >>= 1, Calc (f, f);
	}
	printf ("%d\n", ans[id[X]]);
	return 0;
}
posted @ 2020-12-31 17:12  Rubyonlу  阅读(111)  评论(0编辑  收藏  举报