Luogu3768 简单的数学题
题面
题解
$$ \sum_{i=1}^n\sum_{j=1}^n ij\gcd(i,j) \\
=\sum_{i=1}^n\sum_{j=1}^n\left(ij\sum_{d\mid i,\,d\mid j}\varphi(d)\right) \\
=\sum_{d=1}^n d^2\varphi(d)\,S\left(\left\lfloor\frac nd\right\rfloor\right)^2 \; \left(S(x)=\sum_{i=1}^x i=\frac{x(x+1)}2\right) \\
f=\mathrm{id}^2\cdot\varphi,\;F(n)=\sum_{i=1}^n f(i),\;g=\mathrm{id}^2\cdot 1 \\
\therefore g(x)=x^2,\;(f*g)(x)=\sum_{d\mid x}d^2\varphi(d)\left(\frac xd\right)^2=x^2\sum_{d\mid x}\varphi(d)=x^3 \\
\therefore F(n)=\sum_{i=1}^n i^3-\sum_{i=2}^n i^2F\left(\left\lfloor\frac ni\right\rfloor\right) \\
=\left(\sum_{i=1}^n i\right)^2-\sum_{i=2}^n i^2F\left(\left\lfloor\frac ni\right\rfloor\right)\;\left(\sum_{i=1}^n i^2=\frac{n(n+1)(2n+1)}6\right) $$
于是就可以用杜教筛求 $F$ 了：不超过阈值的 $F$ 用线性筛预处理，更大的 $F$ 按上式递归计算并记忆化，最后对 $\left\lfloor\frac nd\right\rfloor$ 分块求和即可。
代码
#include<cstdio>
#include<cstring>
#include<cctype>
#include<algorithm>
#include<map>
// `register` is deprecated in C++11 and ill-formed since C++17 (clang rejects
// it outright), so RG now expands to nothing; existing `RG int i` declarations
// keep compiling with identical meaning.
#define RG
// Redirects stdin/stdout to <x>.in / <x>.out (used for local testing only).
#define file(x) freopen(#x".in", "r", stdin);freopen(#x".out", "w", stdout);
// Fills the whole array `x` byte-wise with `y`.
#define clear(x, y) memset(x, y, sizeof(x))
// Reads one decimal integer (an optional leading '-' is honoured) from stdin,
// skipping any characters that cannot start a number.
// Returns 0 if EOF is reached before any digit is seen.
inline long long read()
{
    long long data = 0, w = 1;
    // Must be int, not char: getchar() returns EOF (negative) and isdigit()
    // has undefined behaviour for negative values other than EOF.
    int ch = getchar();
    // Stop at EOF instead of spinning forever on truncated input.
    while(ch != EOF && ch != '-' && !isdigit(ch)) ch = getchar();
    if(ch == '-') w = -1, ch = getchar();
    while(ch != EOF && isdigit(ch)) data = data * 10 + (ch - '0'), ch = getchar();
    return data * w;
}
// Sieve bound: prefix sums for x <= maxn are tabulated, larger x are memoised.
const int maxn = 8000010;
long long Mod;              // modulus, read from input
long long n;                // upper bound of the double sum, read from input
long long inv2;             // 2^(Mod-2) mod Mod (inverse of 2 when Mod is prime)
long long inv6;             // 6^(Mod-2) mod Mod (inverse of 6 when Mod is prime)
long long phi[maxn + 10];   // after Init(): prefix sums of i^2 * phi(i) mod Mod
int prime[maxn];            // primes found by the sieve, 1-indexed
int cnt;                    // number of primes stored in prime[]
bool not_prime[maxn + 10];  // sieve marks
std::map<long long, long long> M; // memo of Sum(x) for x > maxn
// Computes x^y mod Mod by binary exponentiation.
// Expects y >= 0; a negative exponent now returns 1 instead of looping forever.
inline long long fastpow(long long x, long long y)
{
    long long ans = 1;
    // Reduce the base first so that x * x below cannot overflow, even when the
    // caller passes x >= Mod or a negative x.
    x %= Mod;
    if(x < 0) x += Mod;
    while(y > 0)
    {
        if(y & 1) ans = ans * x % Mod;
        x = x * x % Mod;
        y >>= 1;
    }
    return ans;
}
// Returns 1 + 2 + ... + x = x(x+1)/2 modulo Mod, dividing by 2 through inv2
// (inv2 must already be initialised).
long long Sum1(long long x)
{
    const long long r = x % Mod;
    const long long prod = r * (r + 1) % Mod;
    return prod * inv2 % Mod;
}
// Returns 1^2 + 2^2 + ... + x^2 = x(x+1)(2x+1)/6 modulo Mod, dividing by 6
// through inv6 (inv6 must already be initialised).
long long Sum2(long long x)
{
    const long long r = x % Mod;
    long long prod = r * (r + 1) % Mod;
    prod = prod * (2 * r + 1) % Mod;
    return prod * inv6 % Mod;
}
// Fills phi[0..maxn] with prefix sums F(x) = sum_{i<=x} i^2 * phi(i) mod Mod.
// First pass: linear (Euler) sieve computing Euler's phi for every i <= maxn.
// Second pass: turns phi(i) into the running prefix sum of i^2 * phi(i).
// (The former `RG` / `register` specifier is dropped: it is ill-formed in C++17.)
void Init()
{
    not_prime[1] = true; phi[1] = 1;
    for(int i = 2; i <= maxn; i++)
    {
        if(!not_prime[i]) prime[++cnt] = i, phi[i] = i - 1;
        // i * prime[j] stays below about 2 * maxn: the loop stops at the first
        // prime whose product exceeds maxn, so the int product cannot overflow.
        for(int j = 1; j <= cnt && i * prime[j] <= maxn; j++)
        {
            not_prime[i * prime[j]] = true;
            if(i % prime[j]) phi[i * prime[j]] = 1ll * phi[i] * phi[prime[j]] % Mod;
            else
            {
                // prime[j] divides i: phi(i * p) = phi(i) * p, and stop here so
                // each composite is marked only by its smallest prime factor.
                phi[i * prime[j]] = 1ll * phi[i] * prime[j] % Mod;
                break;
            }
        }
    }
    for(int i = 1; i <= maxn; i++)
        phi[i] = (phi[i - 1] + 1ll * phi[i] * i % Mod * i % Mod) % Mod;
}
// Returns F(x) = sum_{i<=x} i^2 * phi(i) mod Mod.
// Small x come from the table built by Init(). Larger x use the recurrence
//   F(x) = S(x)^2 - sum_{i=2}^{x} i^2 * F(floor(x/i)),  S(x) = x(x+1)/2,
// grouping i by equal floor(x/i), with results memoised in M.
long long Sum(long long x)
{
    if(x <= maxn) return phi[x];
    // Single lookup instead of find() followed by operator[].
    std::map<long long, long long>::iterator it = M.find(x);
    if(it != M.end()) return it->second;
    long long ans = Sum1(x); ans = ans * ans % Mod;
    for(long long i = 2, j; i <= x; i = j + 1)
    {
        j = x / (x / i);
        // Sum of i^2 over the block [i, j]; may be negative, in (-Mod, Mod).
        long long squares = (Sum2(j) - Sum2(i - 1)) % Mod;
        ans = (ans - Sum(x / i) * squares % Mod) % Mod;
    }
    ans = (ans + Mod) % Mod;
    M[x] = ans;
    return ans;
}
// Reads Mod then n, and prints sum_{d=1}^{n} d^2 * phi(d) * S(floor(n/d))^2 mod Mod,
// where S(x) = x(x+1)/2.
int main()
{
#ifndef ONLINE_JUDGE
    file(cpp);
#endif
    Mod = read(), n = read();
    // Fermat inverses; these are only valid if Mod is a prime not dividing 6
    // (assumed from the problem constraints — not checked here).
    inv2 = fastpow(2, Mod - 2);
    inv6 = fastpow(6, Mod - 2);
    Init();
    long long ans = 0;
    // F(i - 1) for the current block; F(0) = 0. Carrying it forward avoids
    // re-querying the value computed as the previous block's F(j).
    long long prevPrefix = 0;
    for(long long i = 1, j; i <= n; i = j + 1)
    {
        j = n / (n / i);
        long long sq = Sum1(n / i); sq = sq * sq % Mod;
        long long curPrefix = Sum(j);
        long long blockSum = (curPrefix - prevPrefix) % Mod;
        prevPrefix = curPrefix;
        ans = (ans + blockSum * sq % Mod) % Mod;
    }
    printf("%lld\n", (ans + Mod) % Mod);
    return 0;
}