洛谷P4466 [国家集训队] 和与积
所求即为:
\[\large \sum_{a=1}^n\sum_{b=a+1}^n\left[ a+b \mid ab \right]
\]
设 \(\gcd(a,b)=d,a=id,b=jd\),得:
\[\large\begin{aligned}
d(i+j)&\mid ijd^2\\
(i+j)&\mid ijd\\
\end{aligned}
\]
发现因为 \(\gcd(i,j)=1\),所以 \((i+j)\not \mid ij\),这是因为有:
\[\large \gcd(i+j,i)=\gcd(i+j,j)=1
\]
因此得 \((i+j)\mid d\)。原式变为:
\[\large\begin{aligned}
&\sum_{i=1}^{\sqrt n}\sum_{j=i+1}^{\sqrt n}\left[ \gcd(i,j)=1 \right]\left\lfloor \frac{n}{j(i+j)} \right\rfloor \\
=&\sum_{i=1}^{\sqrt n}\sum_{j=i+1}^{\sqrt n}\sum_{d\mid i \and d \mid j}\mu(d)\left\lfloor \frac{n}{j(i+j)} \right\rfloor \\
=&\sum_{d=1}^{\sqrt n}\mu(d)\sum_{i=1}^{\sqrt n}\left[ d\mid i \right]\sum_{j=i+1}^{\sqrt n}\left[ d\mid j \right]\left\lfloor \frac{n}{j(i+j)} \right\rfloor \\
=&\sum_{d=1}^{\sqrt n}\mu(d)\sum_{i=1}^{\left\lfloor\frac{\sqrt n}{d}\right\rfloor}\sum_{j=i+1}^{\left\lfloor\frac{\sqrt n}{d}\right\rfloor}\left\lfloor \frac{n}{d^2j(i+j)} \right\rfloor \\
=&\sum_{d=1}^{\sqrt n}\mu(d)\sum_{i=2}^{\left\lfloor\frac{\sqrt n}{d}\right\rfloor}\sum_{j=i+1}^{2i-1}\left\lfloor \frac{\left\lfloor \frac{n}{d^2i} \right\rfloor}{j} \right\rfloor \\
\end{aligned}
\]
最后一步是分别枚举 \(j\) 和 \(i+j\)。枚举 \(d,j\) 后,\(\left\lfloor \frac{n}{d^2i} \right\rfloor\) 就为定值了,然后就可以数论分块了。
大致分析一下复杂度,得总枚举次数为:
\[\large \sum_{i=1}^{\sqrt n}\frac{\sqrt n}{i}\sqrt{\frac{\sqrt n}{i}}=n^{\frac{3}{4}}\sum_{i=1}^{\sqrt n}\frac{1}{i^{\frac{3}{2}}}
\]
后一项为黎曼函数 \(\zeta(x)\),当 \(x=\frac{3}{2}\) 时,其取值约为 \(2.6\),因此复杂度为 \(O(n^{\frac{3}{4}})\)。
#include<bits/stdc++.h>
#define maxn 47350
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
int n,m,tot;
ll ans;
int p[maxn],mu[maxn];
bool tag[maxn];
void init()
{
mu[1]=1;
for(int i=2;i<=m;++i)
{
if(!tag[i]) p[++tot]=i,mu[i]=-1;
for(int j=1;j<=tot;++j)
{
int k=i*p[j];
if(k>m) break;
tag[k]=true;
if(i%p[j]) mu[k]=mu[i]*mu[p[j]];
else
{
mu[k]=0;
break;
}
}
}
}
ll calc(ll d)
{
ll v=0;
for(int i=2;i<=m/d;++i)
{
ll val=n/(d*d*i);
if(!val) continue;
for(int l=i+1,r;l<=2*i-1;l=r+1)
{
if(val/l==0) break;
r=min(val/(val/l),(ll)2*i-1),v+=val/l*(r-l+1);
}
}
return v;
}
int main()
{
read(n),m=sqrt(n),init();
for(int i=1;i<=m;++i)
if(mu[i])
ans+=mu[i]*calc(i);
printf("%lld",ans);
return 0;
}