杜教筛入门
当学 Min 25 的一个前置知识。
算法内容。
定义 \(S(n)=\sum_{i=1}^nf(i)\)。对于一个函数 \(g\),有:
所以如果存在函数 \(g\),满足:
- \(g(1)\neq 0\)
- \(\sum_{i=1}^n(f * g)(i)\) 可以快速计算
- \(g(i)\) 可以快速计算
可以通过记忆化搜索+数论分块快速计算 \(S(n)\)。可以用 unordered_map
存储结果。
直接计算复杂度为 \(O(n^{\frac{3}{4}})\)。更好的计算是预处理前 \(O(n^{\frac{2}{3}})\) 的 \(S\) 值,可以做到 \(O(n^{\frac{2}{3}})\) 的复杂度。
具体证明可以参考 link。证明本质是积分,简单的。
常用构建函数 \(g\) 技巧:
- \(\sum_{d|n}^n\mu(d)=[n=1]\)
- \(\sum_{d|n}^nu(d)\frac{n}{d}=\varphi(d)\)
- \(\sum_{d|n}^n\varphi(d)=n\)
- \(i^k·(\frac{n}{i})^k=n^k\)
例子:
\(S(n)=\sum_{i=1}^n\mu(i)\)。
令 \(g(n)=1\),也即常函数。
则 \(\sum_{i=1}^n(g * \mu)(i)=[n\ge 1]\)
则
\(S(n)=\sum_{i=1}^n\varphi(i)\)
令 \(g(n)=1\)
则 \(\sum_{i=1}^n(g * \varphi)(i)=\frac{n(n+1)}{2}\)。
则
实操
计算 \(\sum_{i=1}^n\sum_{j=1}^nij\gcd(i,j)\)
先化简:
令 \(h(n)=\frac{n^2(n+1)^2}{4}\),\(f(n)=n^2\varphi(n)\)。
对于上式,后者可以数论分块,问题化为求解 \(f\) 的前缀和。
令 \(g(n)=n^2\)
则 \(\sum_{i=1}^n(f * g)(i=\sum_{i=1}^n\sum_{d|i}d^2\varphi(d) * \frac{i^2}{d^2}=\sum_{i=1}^ni^3=\frac{n^2(n+1)^2}{4}\)
则 \(g(1)=1\)
则 \(g(i)\) 可以快速计算。
所以 \(S_f(n)=\frac{n^2(n+1)^2}{4}-\sum_{i=2}^ni^2S_f(\lfloor\frac{n}{i}\rfloor)\)
数论分块即可。
模版代码一份。
#include<bits/stdc++.h>
using namespace std;
#define N 1050500
#define int long long
const int it=1e6+7;
int v[N],pri[N],tot,p,phi[N],n,m,s[N],inv2,inv6,inv4;
int power(int a,int b){
int ans=1;
while(b){
if(b&1)ans=ans*a%p;
a=a*a%p;b>>=1;
}
return ans;
}
unordered_map<int,int>sit;
void init(){
phi[1]=1;
for(int i=2;i<it;i++){
if(!v[i]){
pri[++tot]=i;phi[i]=i-1;
}
for(int j=1;j<=tot&&i*pri[j]<it;++j){
v[pri[j]*i]=1;
if(i%pri[j]==0){
phi[pri[j]*i]=pri[j]*phi[i];break;
}
phi[pri[j]*i]=phi[pri[j]]*phi[i];
}
}
for(int i=1;i<it;++i)s[i]=(s[i-1]+i*i%p*phi[i]%p)%p;
}
int h(int n){
n%=p;
return n*n%p*(n+1)%p*(n+1)%p*inv4%p;
}
int pfs(int n){
n%=p;
return n*(n+1)%p*(n+n+1)%p*inv6%p;
}
int calc(int n){
if(n<it)return s[n];
if(sit[n]!=0)return sit[n];
int res=h(n);int lst=1,cur=0;
for(int l=2,r;l<=n;l=r+1){
r=min(n,n/(n/l));cur=pfs(r);
res=(res+p-(cur-lst)%p*calc(n/l)%p)%p;lst=cur;
}res=(res%p+p)%p;
return sit[n]=res;
}
signed main(){
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
cin>>p;cin>>n;inv2=power(2,p-2);inv4=power(4,p-2);inv6=power(6,p-2);
init();
int ans=0,lst=0;
for(int l=1,r;l<=n;l=r+1){
r=min(n,n/(n/l));int cur=calc(r);
ans+=(cur-lst)*h(n/l)%p;ans%=p;lst=cur;
}
ans=(ans%p+p)%p;
cout<<ans<<"\n";
}