题解 pyh的求和
题目
思路
不妨设 \(n\le m\)
设 \(T=dk\)
设 \(f(n)=\sum\limits_{k\mid n}\mu(\frac{T}{k})\frac{k}{\varphi(k)}\), \(s(k,n)=\sum\limits_{i=1}^{n}\varphi(ik)\)
先在 \(O(n)\) 的时间复杂度内先线性筛出 \(\mu(i)\) 和 \(\varphi(i)\) 的值,再通过埃氏筛的方法在 \(O(n\ln n)\) 的时间复杂度内求出 \(g(k)\) 和 \(s(k,n)\) 的值
但是,我们显然不能接受单次 \(O(n)\) 的时间复杂度,因此我们考虑再次记忆化。
设 \(T(n,m,k)=\sum\limits_{T=1}^{k}g(T)\cdot s(T,n)\cdot s(T,m)\) ,预处理 \(T\) 即可,但我们发现这样做的时间复杂度和空间复杂度也是我们无法接收的,因此我们可以考虑分段处理。
我们设定一个阈值 \(top\) ,对于 \(i\le \left\lfloor{\frac{m}{top}}\right\rfloor\) 的部分我们暴力计算,对于 \(i>\left\lfloor{\frac{m}{top}}\right\rfloor\) 的部分我们预处理。又因为当 \(i>\left\lfloor{\frac{m}{top}}\right\rfloor\) 时 \(\left\lfloor{\frac{m}{i}}\right\rfloor\le top\) ,因此我们预处理的时间复杂度就控制在了 \(O(n+n\ln n+top n)\) 内,处理的时间复杂度也就在 \(O(\frac{m}{top}+\sqrt{n})\) 内,只要选一个合适的 \(top\) 就可以了。
PS: 本题空间卡的很紧,因此不能开 long long
,需要强制类型转换
代码
#include<iostream>
#include<cstdio>
#include<vector>
#define mod 998244353
using namespace std;
const int maxn=100000,top=30;
int n,m;
bool vis[maxn+5];
int cnt,p[maxn+5];
int phi[maxn+5],inv_phi[maxn+5],mu[maxn+5];
int inv(int x,int y) {
int res=1;
while(y>0) {
if(y&1)res=((long long)res*x)%mod;
x=((long long)x*x)%mod;
y=(y>>1);
}
return res;
}
int g[maxn+5];
vector<int>s[maxn+5];
vector<int>T[top+5][top+5];
void pre() {
inv_phi[1]=phi[1]=mu[1]=1;
for(int i=2; i<=maxn; i++) {
if(!vis[i])p[++cnt]=i,phi[i]=i-1,mu[i]=-1;
inv_phi[i]=inv(phi[i],mod-2);
for(int j=1; j<=cnt&&i*p[j]<=maxn; j++) {
vis[i*p[j]]=true;
if(i%p[j]==0) {
phi[i*p[j]]=phi[i]*p[j];
mu[i*p[j]]=0;
break;
} else {
phi[i*p[j]]=phi[i]*phi[p[j]];
mu[i*p[j]]=mu[i]*mu[p[j]];
}
}
}
for(int i=1; i<=maxn; i++) {
s[i].push_back(0);
for(int j=i; j<=maxn; j+=i) {
s[i].push_back((s[i][j/i-1]+phi[j])%mod);
g[j]=(g[j]+(long long)mu[j/i]*i*inv_phi[i]%mod)%mod;
}
}
for(int i=1; i<=top; i++) {
for(int j=1; j<=top; j++) {
T[i][j].push_back(0);
for(int k=1;k<=maxn/i&&k<=maxn/j;k++){
T[i][j].push_back((T[i][j][k-1]+(long long)s[k][i]*s[k][j]%mod*g[k]%mod)%mod);
}
}
}
}
int main() {
pre();
int t;
cin>>t;
while(t--) {
scanf("%d%d",&n,&m);
if(n>m)swap(n,m);
int ans=0;
for(int i=1;i<=m/top;i++){
ans=(ans+(long long)s[i][n/i]*s[i][m/i]%mod*g[i]%mod);
}
for(int l=m/top+1,r; l<=n; l=r+1) {
r=min(n/(n/l),m/(m/l));
ans=(ans+(T[n/l][m/l][r]-T[n/l][m/l][l-1]+mod)%mod)%mod;
}
printf("%d\n",ans);
}
return 0;
}