洛谷P3321「序列统计」
题目描述
小 \(C\) 有一个集合 \(S\),里面的元素都是小于 \(m\) 的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为 \(n\) 的数列,数列中的每个数都属于集合 \(S\) 。
小 \(C\) 用这个生成器生成了许多这样的数列。但是小 \(C\) 有一个问题需要你的帮助:给定整数 \(x\),求所有可以生成出的,且满足数列中所有数的乘积 \(\mod m\) 的值等于 \(x\) 的不同的数列的有多少个。
小C认为,两个数列 \(A\) 和 \(B\) 不同,当且仅当 \(\exists i \;\text{s.t.} A_i\neq B_i\)。另外,小 \(C\) 认为这个问题的答案可能很大,因此他只需要你帮助他求出答案对 \(1004535809\) 取模的值就可以了。
输入格式
一行,四个整数,\(n,m,x,\left | S \right |\),其中 \(\left | S \right |\) 为集合 \(S\) 中元素个数。
第二行,\(\left | S \right |\) 个整数,表示集合 \(S\) 中的所有元素。
输出格式
一行一个整数表示答案。
输入输出样例
输入
4 3 1 2
1 2
输出
8
说明/提示
【样例说明】
可以生成的满足要求的不同的数列有
\((1,1,1,1),\;(1,1,2,2),\;(1,2,1,2),\;(1,2,2,1),\;(2,1,1,2),\;(2,1,2,1),\;(2,2,1,1),\;(2,2,2,2)\)。
【数据规模和约定】
对于 \(10\%\) 的数据,\(1\leq n\leq 1000\);
对于 \(30\%\) 的数据,\(3\leq m\leq 100\);
对于 \(60\%\) 的数据,\(3\leq m\leq 800\);
对于 \(100\%\) 的数据,\(1\leq n\leq 10^9,3\leq m\leq 8000,1\leq x <m\)。
\(m\) 为质数,输入数据保证集合 \(S\) 中元素不重复。
题解
\(m\) 的值域很小,考虑用Triple一题的思路,将集合中的数放到多项式上,问题就可以转化为一个式子:
这已经很像我们多项式乘法的式子了,但是唯一的不同是,这里的 \(i\) 和 \(j\) 是相乘的
考虑如何将乘法转化成加法?
可以利用高中数学里的对数
对数有个很好的性质:
不妨将 \(i\) 和 \(j\) 都用一个数来取对数,得到 \(log^i\) 和 \(log^j\)
在询问 \(x\) 地方的值的时候,相当于是询问 \(log^x\) 地方的值
现在问题转化为
问题又来了,要选取哪个底数来对所有的值取对数呢?
利用原根
如果连原根都不知道是什么的小朋友,可以去百度百科初步了解一下,不会原根,你怎么学的 \(NTT\)?
我们想把所有值取 \(log\) 要保证什么?
比如说我们取的底数是 \(g\)
我们需要 \(1\sim m-1\) 的 \(log_g^i\mod log_g^m\) 互不相同
即 \(1\sim m-1\) 的 \(g^i\mod m\) 互不相同
这不就是原根的第二个性质嘛
\(1\sim m-1\) 的 \(g^i\) 正好一一对应了 \(1\sim m - 1\) 的所有值
我们就可以把原题中集合 \(S\) 的每个值去取对数了
然后我们就可以得到一个 \(1\sim m-1\) 的多项式,由于没有常数项难以处理,直接变成 \(0\sim m-2\) 的多项式来处理
对于取模,就很好处理了
两个长度为 \(m-2\) 的多项式相乘,对于得到的多项式的 \(m-1\) 次项以后,同时也对答案造成了贡献,次数模上模数加上贡献即可
我们要选 \(n\) 个数,而且没有像Triple一样“不能选重复的限制”,不用什么乱七八糟的容斥,所以直接 \(f(x)^k\) 即可
我们可以不用像多项式快速幂那么麻烦的快速幂,而且 \(a_0\) 不一定为 \(1\),也用不了
可以像普通实数快速幂一样,每次只留前 \(m-2\) 位,只不过效率是 \(nlog^{2n}\)
代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
typedef long long ll;
typedef unsigned long long ull;
using namespace std;
const int maxn = 2e5 + 50, INF = 0x3f3f3f3f, mod = 1004535809, inv3 = 334845270;
inline int read () {
register int x = 0, w = 1;
register char ch = getchar ();
for (; ch < '0' || ch > '9'; ch = getchar ()) if (ch == '-') w = -1;
for (; ch >= '0' && ch <= '9'; ch = getchar ()) x = x * 10 + ch - '0';
return x * w;
}
inline void write (register int x) {
if (x / 10) write (x / 10);
putchar (x % 10 + '0');
}
int n, m, g, X, s, len = 1, bit;
bool vis[maxn];
int res[maxn], tmp[maxn], ans[maxn];
int f[maxn], id[maxn], rev[maxn];
inline int gqpow (register int a, register int b, register int ans = 1) {
for (; b; b >>= 1, a = 1ll * a * a % m)
if (b & 1) ans = 1ll * ans * a % m;
return ans;
}
inline int Get_g (register int m) {
for (register int i = 0; i < m; i ++) {
memset (vis, 0, sizeof 4 * m);
for (register int k = 1, tmp; k <= m - 1; k ++) {
tmp = gqpow (i, k);
if (vis[tmp]) goto end;
else vis[tmp] = 1;
}
return i;
end:;
}
return -1;
}
inline int qpow (register int a, register int b, register int ans = 1) {
for (; b; b >>= 1, a = 1ll * a * a % mod)
if (b & 1) ans = 1ll * ans * a % mod;
return ans;
}
inline void NTT (register int len, register int * a, register int opt) {
for (register int i = 1; i < len; i ++) if (i < rev[i]) swap (a[i], a[rev[i]]);
for (register int d = 1; d < len; d <<= 1) {
register int w1 = qpow (opt, (mod - 1) / (d << 1));
for (register int i = 0; i < len; i += d << 1) {
register int w = 1;
for (register int j = 0; j < d; j ++, w = 1ll * w * w1 % mod) {
register int x = a[i + j], y = 1ll * w * a[i + j + d] % mod;
a[i + j] = (x + y) % mod, a[i + j + d] = (x - y + mod) % mod;
}
}
}
}
inline void Calc (register int * a, register int * b) {
memset (res, 0, 4 * len), memset (tmp, 0, 4 * len);
for (register int i = 0; i < m; i ++) res[i] = a[i], tmp[i] = b[i], a[i] = 0;
NTT (len, res, 3), NTT (len, tmp, 3);
for (register int i = 0; i < len; i ++) res[i] = 1ll * res[i] * tmp[i] % mod;
NTT (len, res, inv3);
register int inv = qpow (len, mod - 2);
for (register int i = 0; i < len; i ++) res[i] = 1ll * res[i] * inv % mod, a[i % (m - 1)] = (a[i % (m - 1)] + res[i]) % mod;
}
int main () {
n = read(), m = read(), X = read(), s = read(), g = Get_g (m);
for (register int i = 0; i <= m - 2; i ++) id[gqpow (g, i)] = i;
while (len < m << 1) len <<= 1, bit ++;
for (register int i = 1; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit - 1);
for (register int i = 1; i <= s; i ++) {
register int x = read();
if (x) f[id[x]] ++;
}
ans[0] = 1;
while (n) {
if (n & 1) Calc (ans, f);
n >>= 1, Calc (f, f);
}
printf ("%d\n", ans[id[X]]);
return 0;
}