【洛谷P3321】序列统计

题目

题目链接:https://www.luogu.com.cn/problem/P3321
小C有一个集合 \(S\),里面的元素都是小于 \(m\) 的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为 \(n\) 的数列,数列中的每个数都属于集合 \(S\)
小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数 \(x\),求所有可以生成出的,且满足数列中所有数的乘积 \(\bmod \ m\) 的值等于 \(x\) 的不同的数列的有多少个。
小C认为,两个数列 \(A\)\(B\) 不同,当且仅当 \(\exists i \text{ s.t. } A_i \neq B_i\)。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案对 \(1004535809\) 取模的值就可以了。
\(n\leq 10^9,m\leq 8000\)

思路

之前在 GMOJ 这道题时限开 \(5s\) 被我 \(O(m^2\log n)\) 艹过去了。
首先 \(60\)pts 的倍增 dp 就是设 \(f[i][j]\) 表示选了 \(2^i\) 个数,乘积 \(\bmod p\) 之后的结果为 \(j\) 的方案数。
转移为

\[f[k][l]=\sum^{}_{i\times j\bmod p=l}f[k-1][i]\times f[k-1][j] \]

然后二进制拆分即可。
如果这个乘号是加号的话,我们就可以 NTT 优化了。
考虑如何把乘号变为加号,因为 \(\log_ab+\log_ac=\log_a(bc)\),所以可以用对数进行转化。
但是我们需要保证转化后对于任意两个 \(x,y\in [1,m)\)\(x\neq y\),都有 \(\log_a x\neq \log_a y\),由于 \(m\) 是质数,所以我们取 \(m\) 的原根即可。
接下来就和 \(60\)pts 的做法一样了。将每一数转化为对数之后扔进一个多项式里,然后倍增计算即可。
时间复杂度 \(O(m\log n\log m)\)

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=18010,MOD=1004535809;
int n,m,s,l,G,lim,a[N],rev[N];
ll f[N],g[N],h[N];

ll fpow(ll x,ll k,ll mod=(ll)MOD)
{
	ll ans=1;
	for (;k;k>>=1,x=x*x%mod)
		if (k&1) ans=ans*x%mod;
	return ans;
}

int findg(int p)
{
	vector<int> d;
	for (int i=2;i<=p-1;i++)
		if ((p-1)%i==0) d.push_back(i);
	for (int i=1;i<=p;i++)
	{
		bool flag=1;
		for (int j=0;j<d.size();j++)
			if (fpow(i,(p-1)/d[j],p)==1) { flag=0; break; }
		if (flag) return i;
	}
}

void NTT(ll *f,bool tag)
{
	for (int i=0;i<lim;i++)
		if (i<rev[i]) swap(f[i],f[rev[i]]);
	for (int k=1;k<lim;k<<=1)
	{
		ll tmp=fpow((tag?3:334845270),(MOD-1)/(k<<1));
		for (int i=0;i<lim;i+=(k<<1))
		{
			ll w=1;
			for (int j=0;j<k;j++,w=w*tmp%MOD)
			{
				ll x=f[i+j],y=w*f[i+j+k]%MOD;
				f[i+j]=(x+y)%MOD; f[i+j+k]=(x-y)%MOD;
			}
		}
	}
}

int main()
{
	scanf("%d%d%d%d",&n,&m,&s,&l);  // 十分优雅的读入
	G=findg(m);
	for (int i=1;i<m;i++)
		a[fpow(G,i,m)]=i;
	for (int i=1,x;i<=l;i++)
	{
		scanf("%d",&x);
		if (x) f[a[x]]++;
	}
	g[0]=lim=1;
	while (lim<=2*m) lim<<=1;
	for (int i=0;i<lim;i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)?(lim>>1):0);
	ll inv=fpow(lim,MOD-2);
	for (int k=0;k<=30;k++)
	{
		if (n&(1<<k))
		{
			memcpy(h,f,sizeof(f));
			NTT(g,1); NTT(h,1);
			for (int i=0;i<lim;i++) g[i]=g[i]*h[i]%MOD;
			NTT(g,0);
			for (int i=1;i<m;i++)
				g[i]=(g[i]+g[i+m-1])*inv%MOD;
			for (int i=m;i<lim;i++) g[i]=0;
		}
		NTT(f,1);
		for (int i=0;i<lim;i++) f[i]=f[i]*f[i]%MOD;
		NTT(f,0);
		for (int i=1;i<m;i++)
			f[i]=((f[i]+f[i+m-1])*inv%MOD+MOD)%MOD;
		for (int i=m;i<lim;i++) f[i]=0;
	}
	printf("%lld",(g[a[s]]%MOD+MOD)%MOD);
	return 0;
}
posted @ 2021-01-07 21:00  stoorz  阅读(72)  评论(0编辑  收藏  举报