题解 稀疏阶乘问题
即为统计 \(m\mid F(x)\) 的个数
有性质 \(m\mid F(x)\to m\mid F(x+m)\),所以按 \(x \bmod m\) 的余数分类处理
考虑 \(m\) 的每个质因子,限制变为 \(\prod p_i^{c_i}\mid F(x)\)
那么对每个 \(r\) 和每个 \(p_i^{c_i}\) 求出符合上述限制的最小 \(x\),再整体取 max 即可得到 \(m\mid F(x), x\equiv r\pmod m\) 的最小 \(x\)
然后问题即变为区间内 \(x\equiv r\pmod m\) 的 \(x\) 的个数
康康怎么对每个 \(r\) 和每个 \(p_i^{c_i}\) 求符合限制的最小 \(x\)
转为求 \(\prod\limits_{k=0}^{lim}(r-k^2)\equiv 0\pmod{p_i^{c_i}}\) 的最小 \(lim\)
由同余性质知 \(lim<p_i\times c_i\) 否则无解
于是枚举 \(k\in[0, p_i\times c_i)\),它能影响到 \(r\equiv k^2\pmod{p_i}\) 的 \(r\),更新其 \(lim\)
然后对每个 \(r\) 根据 \(lim\) 的定义(\(\lfloor\sqrt{x-1}\rfloor\))不难反推出最小的 \(x\)
复杂度 \(O(\sum \lfloor\frac{m}{p_i}\rfloor p_ic_i)=O(m\log m)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f3f3f3f3f
#define N 1000010
#define ll long long
#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline ll read() {
ll ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
ll l, r, m;
namespace force{
ll ans;
void solve() {
for (int i=l; i<=r; ++i) {
int lim=floor(sqrt(i-1));
ll tem=1;
for (ll k=0; k<=lim&&tem; ++k)
tem=tem*(i-k*k)%m;
if (!tem) ++ans;
}
cout<<ans<<endl;
}
}
namespace task{
int p[N], c[N], ans[N], rest[N], lim[N], top, sum;
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a,b>>=1) if (b&1) ans=ans*a; return ans;}
int F(int x) {
int lim=floor(sqrt(x-1));
ll tem=1;
for (ll k=0; k<=lim&&tem; ++k)
tem=tem*(x-k*k)%m;
return tem;
}
inline int qcnt(int t, int r) {return r/m+(r%m>=t%m)-(t!=0);}
void solve() {
int tem=m;
for (int i=2; i<=m; ++i) if (tem%i==0) {
p[++top]=i;
do {tem/=i; ++c[top];} while (tem%i==0);
}
// cout<<"p: "; for (int i=1; i<=top; ++i) cout<<p[i]<<' '; cout<<endl;
// cout<<"c: "; for (int i=1; i<=top; ++i) cout<<c[i]<<' '; cout<<endl;
for (int i=1; i<=top; ++i) {
// cout<<"i: "<<i<<endl;
int mod=qpow(p[i], c[i]);
for (int j=0; j<m; ++j) lim[j]=INF, rest[j]=1;
for (int k=0; k<p[i]*c[i]; ++k) {
int r=k*k%p[i];
for (int t=r; t<m; t+=p[i]) {
rest[t]=rest[t]*(t-k*k)%mod;
if (!rest[t]) lim[t]=min(lim[t], k);
}
}
// cout<<"lim: "; for (int j=0; j<m; ++j) cout<<lim[j]<<' '; cout<<endl;
for (int t=0; t<m; ++t) {
int x=lim[t]*lim[t]+1;
ans[t]=max(ans[t], x);
}
// cout<<"ans: "; for (int j=0; j<m; ++j) cout<<ans[j]<<' '; cout<<endl;
}
for (int t=0; t<m; ++t) if (ans[t]<=r)
sum+=qcnt(t, r)-qcnt(t, max(l, ans[t])-1);
printf("%lld\n", sum);
}
}
signed main()
{
freopen("sfac.in", "r", stdin);
freopen("sfac.out", "w", stdout);
l=read(); r=read(); m=read();
// force::solve();
task::solve();
return 0;
}