bzoj3512-DZY Loves Math IV
题意
求
其中 \(\varphi\) 为欧拉函数。\(n\le 10^5,m\le 10^9\) 。
分析
简单的题目,感觉非常妙啊!这题学到了很多东西。
观察到 \(n\) 比较小,可以枚举 \(n\) ,转化为求 \(S(n,m)=\sum _{i=1}^m\varphi (ni)\) 。由于 \(\varphi\) 函数的性质,若 \(n=\prod _{i=1}^qp_i^{c_i}\) ,那么有 \(\varphi (n)=\varphi (\prod _{j=1}^qp_j)\prod _{i=1}^q p_i^{c_i-1}\) ,即令 \(n\) 的每个质因数只留下一次,剩下的直接乘上去即可。用 \(y\) 表示 \(n\) 剩下的因数,\(w\) 表示 \(n\) 的所有质因数的一次的乘积,即 \(yw=n\) ,那么有
中间绝妙地运用了 \(n=\sum _{d|n}\varphi (d)\) 代换,由于除去了gcd,所以后面与前面是互质的,直接把第一项和第三项合并即可。
最后我们得到了一个递归的形式。当 \(n=1\) 的时候,\(S(n,m)=\sum _{i=1}^m\varphi (i)\) ,可以用杜教筛的方法在一次 \(O(m^\frac{2}{3})\) 的时间复杂度内求出,在这样做的同时,我们还记忆化了所有 \(\varphi(\lfloor \frac{m}{x}\rfloor)\) 的值,所以这个复杂度只需要计算一次。\(S(n,m)\) 从第一层开始就只有 \(n\times \sqrt m\) 种取值,每一次上界是 \(\sqrt n\) ,所以总复杂度为 \(O(n(\sqrt m+\sqrt n)+m^{\frac{2}{3}})\) 。
很妙的题啊!!
代码
#include<bits/stdc++.h>
#include<tr1/unordered_map>
using namespace std;
using namespace std::tr1;
typedef long long giant;
const int maxn=1e6+1;
const int then=1e5+1;
const int q=1e9+7;
inline int Plus(int x,int y) {return (static_cast<giant>(x)+static_cast<giant>(y))%q;}
inline int Sub(int x,int y) {return Plus(x,q-y);}
inline int Multi(int x,int y) {return static_cast<giant>(x)*static_cast<giant>(y)%q;}
unordered_map<int,int> varphi,s[then];
int p[maxn],ps=0,varphi[maxn],from[maxn];
bool np[maxn];
int get(int n) {
int &ret=varphi[n];
if (ret) return ret; else ret=((giant)n*(n+1)/2)%q;
for (int i=2,j;i<=n;i=j+1) {
j=n/(n/i);
ret=Sub(ret,Multi(get(n/i),j-i+1));
}
return ret;
}
int S(int n,int m) {
if (n==1) return get(m);
if (m==1) return varphi[n];
if (!m) return 0;
if (s[n][m]) return s[n][m];
int w=1,y=1,&ret=s[n][m]=0;
vector<int> div;
div.clear();
while (n>1) {
int x=n/from[n];
w*=x,n=from[n];
div.push_back(x);
while (n%x==0) y*=x,n/=x;
}
int s=(1<<(int)div.size());
for (int i=0;i<s;++i) {
int d=1;
for (int j=0;j<div.size();++j) if ((i>>j)&1) d*=div[j];
ret=Plus(ret,Multi(varphi[w/d],S(d,m/d)));
}
return ret=Multi(y,ret);
}
int main() {
#ifndef ONLINE_JUDGE
freopen("test.in","r",stdin);
#endif
varphi[1]=1;
for (int i=2;i<maxn;++i) {
if (!np[i]) p[++ps]=i,varphi[i]=i-1,from[i]=1;
for (int j=1,tmp;j<=ps && (tmp=i*p[j])<maxn;++j) {
np[tmp]=true;
from[tmp]=i;
if (i%p[j]==0) {
varphi[tmp]=varphi[i]*p[j];
break;
}
varphi[tmp]=varphi[i]*varphi[p[j]];
}
}
int tmp=0;
for (int i=1;i<maxn;++i) varphi[i]=(tmp=Plus(tmp,varphi[i]));
int n,m,ans=0;
scanf("%d%d",&n,&m);
for (int i=1;i<=n;++i)
ans=Plus(ans,S(i,m));
printf("%d\n",ans);
return 0;
}