【洛谷3321_BZOJ3992】[SDOI2015]序列统计(原根_多项式)
题目:
分析:
一个转化思路比较神(典型?)的题……
一个比较显然的\(O(n^3)\)暴力是用\(f[i][j]\)表示选了\(i\)个数,当前积在模\(m\)意义下为\(j\)的方案数,每次转移枚举\(S\)的元素,即(\(k^{-1}\)表示\(k\)在模\(m\)意义下的逆元):
\[f[i][j]=\sum_{k\in S} f[i-1][jk^{-1}]
\]
事实上写的时候通常是从\(f[i][j]\)往\(f[i+1][jk]\)贡献
然后通过Orz题解发现那个乘法\(jk^{-1}\)非常地丑,如果能变成加法就好了qwq。注意到保证\(m\)是一个质数,所以是可以求原根的。原根的好处在于模\(m\)的意义下\(g^i(i\in [0,m-1))\)与\(a(a \in [1,m-1])\)一一对应。所以如果用原根的幂代替原数,原数的乘法就变成了指数的加法。即设\(f[i][j]\)表示选了\(i\)个数,积在模\(m\)意义下是\(g^j\)的方案数,则:
\[f[i][j]=\sum_{g^k \in S} f[i-1][j-k]
\]
上面那个东西像不像卷积?一点都不像看到\(j-k\)难道你就不想再来一个跟\(k\)有关的东西吗?于是定义一个函数\(h(x)\):
\[h(x)=\begin{cases}1&(g^x\in S) \\
0&(otherwise)\end{cases}\]
那么就成了:
\[f[i][j]=\sum_{k=0}^{m-2}h(k)f[i-1][j-k]
\]
设一个多项式\(A\),第\(i\)项系数是\(h(i)\),则一开始\(f[1]\)就是\(A\),所以答案就是\(A^n\)的\(x\)次项系数,写一个多项式快速幂即可。
代码:
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cctype>
#include <set>
using namespace std;
namespace zyt
{
template<typename T>
inline bool read(T &x)
{
char c;
bool f = false;
x = 0;
do
c = getchar();
while (c != EOF && c != '-' && !isdigit(c));
if (c == EOF)
return false;
if (c == '-')
f = true, c = getchar();
do
x = x * 10 + c - '0', c = getchar();
while (isdigit(c));
if (f)
x = -x;
return true;
}
template<typename T>
inline void write(T x)
{
static char buf[20];
char *pos = buf;
if (x < 0)
putchar('-'), x = -x;
do
*pos++ = x % 10 + '0';
while (x /= 10);
while (pos > buf)
putchar(*--pos);
}
typedef long long ll;
const int N = 8010, p = 1004535809;
int n, m;
set<int> s;
namespace Polynomial
{
const int LEN = N << 2;
int omega[LEN], winv[LEN], rev[LEN];
inline int power(int a, int b, const int p = ::zyt::p)
{
int ans = 1;
while (b)
{
if (b & 1)
ans = (ll)ans * a % p;
a = (ll)a * a % p;
b >>= 1;
}
return ans;
}
inline int inv(const int a, const int p = ::zyt::p)
{
return power(a, p - 2);
}
namespace Primitive_Root
{
pair<int, int>prime[N];
int cnt;
inline void get_prime(int n)
{
cnt = 0;
for (int i = 2; i * i <= n; i++)
{
if (n % i == 0)
{
prime[cnt++] = make_pair(i, 0);
while (n % i == 0)
++prime[cnt - 1].second, n /= i;
}
}
if (n > 1)
prime[cnt++] = make_pair(n, 1);
}
inline int get_g(const int n)
{
get_prime(n - 1);
for (int i = 2; i < n; i++)
{
bool flag = true;
for (int j = 0; j < cnt && flag; j++)
flag &= (power(i, (n - 1) / prime[j].first, n) != 1);
if (flag)
return i;
}
return -1;
}
}
void init(const int n, const int lg2)
{
static int g = 0;
if (!g)
g = Primitive_Root::get_g(p);
int w = power(g, (p - 1) / n), wi = inv(w);
omega[0] = winv[0] = 1;
for (int i = 1; i < n; i++)
{
omega[i] = (ll)omega[i - 1] * w % p;
winv[i] = (ll)winv[i - 1] * wi % p;
}
for (int i = 0; i < n; i++)
rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1)));
}
void ntt(int *a, const int *w, const int n)
{
for (int i = 0; i < n; i++)
if (i < rev[i])
swap(a[i], a[rev[i]]);
for (int l = 1; l < n; l <<= 1)
for (int i = 0; i < n; i += (l << 1))
for (int k = 0; k < l; k++)
{
int tmp = (a[i + k] - (ll)a[i + l + k] * w[n / (l << 1) * k] % p + p) % p;
a[i + k] = (a[i + k] + (ll)a[i + l + k] * w[n / (l << 1) * k] % p) % p;
a[i + l + k] = tmp;
}
}
void mul(const int *a, const int *b, int *c, const int n)
{
static int x[LEN], y[LEN];
memcpy(x, a, sizeof(int[n]));
memcpy(y, b, sizeof(int[n]));
int m = 1, lg2 = 0;
while (m < (n << 1))
m <<= 1, ++lg2;
init(m, lg2);
memset(x + n, 0, sizeof(int[m - n]));
memset(y + n, 0, sizeof(int[m - n]));
ntt(x, omega, m), ntt(y, omega, m);
for (int i = 0; i < m; i++)
x[i] = (ll)x[i] * y[i] % p;
ntt(x, winv, m);
int invm = inv(m);
for (int i = 0; i < n; i++)
c[i] = (ll)(x[i] + x[i + n]) * invm % p;
}
void power(const int *a, int b, int *c, const int n)
{
static int x[N];
memcpy(x, a, sizeof(int[n]));
memset(c, 0, sizeof(int[n]));
c[0] = 1;
while (b)
{
if (b & 1)
mul(c, x, c, n);
mul(x, x, x, n);
b >>= 1;
}
}
}
int A[N];
int work()
{
using Polynomial::power;
using Polynomial::Primitive_Root::get_g;
int x, ssize;
read(n), read(m), read(x), read(ssize);
for (int i = 0; i < ssize; i++)
{
int a;
read(a);
s.insert(a);
}
int g = get_g(m), gx = -1;
for (int i = 0; i < m - 1; i++)
{
int tmp = power(g, i, m);
if (s.count(tmp))
A[i] = 1;
if (tmp == x)
gx = i;
}
power(A, n, A, m - 1);
write(A[gx]);
return 0;
}
}
int main()
{
return zyt::work();
}