[SDOI2015]序列统计
[SDOI2015]序列统计
题意:
小C有一个集合\(S\),里面的元素都是小于\(m\)的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为\(n\)的数列,数列中的每个数都属于集合\(S\)。
小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数\(x\),求所有可以生成出的,且满足数列中所有数的乘积%\(m\)的值等于\(x\)的不同的数列的有多少个。
小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案对\(1004535809\)取模的值就可以了。
输入格式:
一行,四个整数\(n,m,x,∣S∣\)其中\(∣S∣\)为集合\(S\)中元素个数。
第二行,\(∣S∣\)个整数,表示集合\(S\)中的所有元素。
输出格式:
一行一个整数表示答案。
输入样例:
4 3 1 2
1 2
输出样例:
8
Solution:
首先定义数组\(f[i][j]\)表示生成到了第\(i\)个数,答案是\(j\)的方案数
\(n\)的大小为\(1e9\),首先可以ksm优化
\(f[i*2][j] = \sum_{a*b \mod m=j} f[i][a]*f[i][b]\)
\(8000\times8000\times\log_{1e9}\)的复杂度还是过不了题
可以发现后面那一坨有点像fft式子,但是条件是乘号。
想想只有对数可以吧加法和乘法联系在一起
这方面的知识点可以参考博主的博客 取模意义下的对数&生成元的查找
然后当我们把第二维全部换成对数时,可以得到式子
\(f[i*2][j] = \sum_{a+b \mod {m-1}=j}f[i][a]*f[i][b]\)
这个式子就直接ntt就行,中间为什么是\(\mod{m-1}\)参考博客即可
代码:
#include<bits/stdc++.h>
#define ll long long
#define R register
using namespace std;
template<class T>
void rea(T &x)
{
char ch=getchar();int f(0);x = 0;
while(!isdigit(ch)) {f|=ch=='-';ch=getchar();}
while(isdigit(ch)) {x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
x = f?-x:x;
}
int ksm(int x, int k, int mod)
{
int ret = 1;
while(k)
{
if(k&1) ret = 1ll*ret*x%mod;
x = 1ll*x*x%mod;
k >>= 1;
}
return ret;
}
int getroot(int mod)
{
int prime[10000], tot = 0;
int num = mod-1;
for(R int i = 2; i*i <= num; ++i)
if(num%i == 0)
{
prime[++tot] = i;
while(num%i == 0) num /= i;
}
if(num > 1) prime[++tot] = num;
num = mod-1;
for(R int i = 2; i <= num; ++i)
{
bool ban = 0;
for(R int j = 1; j <= tot; ++j)
if(ksm(i, num/prime[j], mod) == 1) { ban = 1; break; }
if(!ban) return i;
}
return false;
}
const int N = 10000, mod = 1004535809, G = 3, Gi = ksm(3, mod-2, mod);
int n, m, x, s, base[N<<2], ans[N<<2], pos[N<<2];
map<int, int>Log;
void prepos(int k)
{
int len = (1<<k);
for(R int i = 0; i < len; ++i)
pos[i] = (pos[i>>1]>>1)|((i&1)<<(k-1));
}
void NTT(int *a, int len, int flag)
{
for(R int i = 0; i < len; ++i) if(pos[i] > i) swap(a[pos[i]], a[i]);
for(R int mid = 1; mid < len; mid*=2)
{
int wx = ksm(flag==1?G:Gi, (mod-1)/(mid*2), mod);
for(R int i = 0; i < len; i += mid*2)
{
int w = 1;
for(R int j = i; j < i+mid; ++j)
{
int x = a[j], y = 1ll*a[j+mid]*w%mod;
a[j] = (x+y)%mod, a[j+mid] = (x-y+mod)%mod;
w = 1ll*w*wx%mod;
}
}
}
if(flag == -1)
{
int inv = ksm(len, mod-2, mod);
for(R int i = 0; i < len; ++i) a[i] = 1ll*a[i]*inv%mod;
}
}
void X(int *a, int *b, int len)
{
int A[N<<2], B[N<<2];
for(R int i = 0; i < len; ++i) A[i] = a[i], B[i] = b[i];
NTT(A, len, 1); NTT(B, len, 1);
for(R int i = 0; i < len; ++i) A[i] = 1ll*A[i]*B[i]%mod;
NTT(A, len, -1);
for(R int i = 0; i < m-1; ++i) A[i] = (A[i]+A[m+i-1])%mod;
for(R int i = 0; i < m-1; ++i) a[i] = A[i];
}
int main()
{
rea(n), rea(m), rea(x), rea(s);
int g = getroot(m); for(R int i = 0; i < m-1; ++i) Log[ksm(g, i, m)] = i;
for(R int i = 1; i <= s; ++i) {rea(g);if(g%m) base[Log[g%m]]++;}
int limit = 2, k = 1;
while(limit < m*2) limit <<= 1, k++;
prepos(k);
ans[0] = 1;
while(n)
{
if(n&1) X(ans, base, limit);
X(base, base, limit);
n >>= 1;
}
printf("%d\n", ans[Log[x]]);
return 0;
}