题解 摆
是原题的加强版,我就一并写在那题题解里了
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 20000010
#define ll long long
// #define int long long
// Buffered input: stdin is pulled in 2 MiB chunks into `buf`; [p1, p2) is the unread part.
char buf[1<<21], *p1=buf, *p2=buf;
// Replacement for getchar() served from `buf`. Bytes are returned as unsigned char so a
// 0xFF byte can never be mistaken for EOF; EOF is returned once fread() yields nothing.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; every '-' among them flips the sign.
// Returns 0 if the input ends before any digit appears, instead of spinning forever.
inline long long read() {
    long long ans=0, f=1;
    int c=getchar();  // int, not char: it must be able to hold EOF distinctly
    while (!std::isdigit(c)) {
        if (c==EOF) return 0;
        if (c=='-') f=-f;
        c=getchar();
    }
    while (std::isdigit(c)) {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*f;
}
long long n, c;
const long long mod=998244353;
// Computes a^b mod `mod` by binary exponentiation; the result is always in [0, mod).
// `a` may be any long long (negative or >= mod): it is reduced into [0, mod) first so
// the squaring a*a can never overflow. `b` must be non-negative; qpow(a, 0) == 1.
inline long long qpow(long long a, long long b) {
    a%=mod;
    if (a<0) a+=mod;
    long long ans=1;
    for (; b>0; a=a*a%mod, b>>=1)
        if (b&1) ans=ans*a%mod;
    return ans;
}
namespace task1{
ll mp1[N], mp2[N], val, sqr;
// unordered_map<int, ll> mp;
ll f(int n) {
// if (mp.find(n)!=mp.end()) return mp[n];
if (n<=sqr) {if (mp1[n]) return mp1[n];}
else {if (mp2[::n/n]) return mp2[::n/n];}
ll ans=0;
for (int l=2,r; l<=n; l=r+1) {
r=n/(n/l);
ans=(ans+(r-l+1)*f(n/l))%mod;
}
ans=(1+val*ans)%mod;
if (n<=sqr) mp1[n]=ans;
else mp2[::n/n]=ans;
return ans;
}
void solve() {
sqr=sqrt(n);
val=1ll*c*qpow(1-c, mod-2)%mod;
printf("%lld\n", (qpow(1-c, n)*((f(n)+f(n)*val)%mod)%mod+mod)%mod);
}
}
namespace task2{
bool npri[N];
const ll base=13131;
unordered_map<ll, ll> mp;
int pri[N], f[N], lowc[N], pcnt;
ll mp1[N], mp2[N], pw[N], h[N], val, sqr;
ll F(ll n) {
if (n<N) return f[n];
if (n<=sqr) {if (mp1[n]) return mp1[n];}
else {if (mp2[::n/n]) return mp2[::n/n];}
ll ans=0;
for (ll l=2,r; l<=n; l=r+1) {
r=n/(n/l);
ans=(ans+(r-l+1)%mod*F(n/l))%mod;
}
ans=(1+val*ans)%mod;
if (n<=sqr) mp1[n]=ans;
else mp2[::n/n]=ans;
return ans;
}
ll force_f(ll n) {
ll i, ans=f[1];
for (i=2; i*i<n; ++i) if (!(n%i)) ans=(ans+f[i]+f[n/i])%mod;
if (i*i==n) ans=(ans+f[i])%mod;
return ans*val%mod;
}
void solve() {
sqr=sqrt(n);
val=1ll*c*qpow(1-c, mod-2)%mod;
f[1]=pw[0]=1;
for (int i=1; i<=50; ++i) pw[i]=pw[i-1]*base%mod;
for (int i=2; i<N; ++i) {
if (!npri[i]) pri[++pcnt]=i, h[i]=pw[lowc[i]=1], f[i]=val;
for (int j=1,x; j<=pcnt&&1ll*i*pri[j]<N; ++j) {
npri[x=i*pri[j]]=1;
if (!(i%pri[j])) {
lowc[x]=lowc[i]+1;
h[x]=((h[i]-pw[lowc[i]]+pw[lowc[x]])%mod+mod)%mod;
break;
}
else lowc[x]=1, h[x]=(h[i]+pw[1])%mod;
}
}
for (int i=2; i<N; ++i)
if (mp.find(h[i])!=mp.end()) f[i]=mp[h[i]];
else f[i]=mp[h[i]]=force_f(i);
for (int i=2; i<N; ++i) f[i]=(f[i]+f[i-1])%mod;
printf("%lld\n", (qpow(1-c, n)*((F(n)+F(n)*val)%mod)%mod+mod)%mod);
}
}
signed main()
{
    // Reads from bigben.in and writes to bigben.out instead of the console.
    freopen("bigben.in", "r", stdin);
    freopen("bigben.out", "w", stdout);
    n = read();
    c = read();
    // With c == 1 the value 1 - c has no modular inverse; the n <= 2 cases are answered directly.
    if (c == 1 && n <= 2) {
        puts("1");
        return 0;
    }
    // task1::solve();  // older solver without the sieve, kept for reference
    task2::solve();
    return 0;
}