[Luogu] P2260 [清华集训2012]模积和
Description
求\(\sum\limits_{i=1}^n\sum\limits_{j=1}^m(n\bmod{i})\times(m\bmod{j})\),其中 \(i\ne{j}\),答案对 \(19940417\) 取模。
Solution
不妨设\(n\le{m}\),那么
\(\sum\limits_{i=1}^n\sum\limits_{j=1}^m(n\bmod{i})\times(m\bmod{j}),i\ne{j}\)
\(=\sum\limits_{i=1}^n\sum\limits_{j=1}^m(n\bmod{i})\times(m\bmod{j})-\sum\limits_{i=1}^n(n\bmod{i})\times(m\bmod{i})\)
\(=\left(\sum\limits_{i=1}^n(n-i\lfloor\frac{n}{i}\rfloor)\right)\left(\sum\limits_{j=1}^m(m-j\lfloor\frac{m}{j}\rfloor)\right)-\sum\limits_{i=1}^n(n-i\lfloor\frac{n}{i}\rfloor)(m-i\lfloor\frac{m}{i}\rfloor)\)
前一项是两个互相独立的和相乘,各自直接整除分块即可;后一项稍麻烦,把乘积展开:
\(\sum\limits_{i=1}^n(n-i\lfloor\frac{n}{i}\rfloor)(m-i\lfloor\frac{m}{i}\rfloor)\)
\(=\sum\limits_{i=1}^n(nm-i\lfloor\frac{n}{i}\rfloor{m}-i\lfloor\frac{m}{i}\rfloor{n}+i^2\lfloor\frac{n}{i}\rfloor\lfloor\frac{m}{i}\rfloor)\)
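每一项同样可以整除分块:取 \(r=\min(n,\lfloor\frac{n}{\lfloor n/l\rfloor}\rfloor,\lfloor\frac{m}{\lfloor m/l\rfloor}\rfloor)\),使区间 \([l,r]\) 内 \(\lfloor\frac{n}{i}\rfloor\) 与 \(\lfloor\frac{m}{i}\rfloor\) 同时保持不变。\(i\) 的区间和用等差数列求和;\(i^2\) 的区间和用平方和公式 \(\sum\limits_{i=1}^ni^2=\frac{n(n+1)(2n+1)}{6}\) 的两个前缀和作差即可。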
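注意模数 \(19940417=7\times2848631\) 不是质数,不能用费马小定理求逆元。但 \(\gcd(6,19940417)=1\),所以 6 的逆元仍然存在,可以用欧拉定理或扩展欧几里得求出 \(inv_6=3323403\)(验证:\(6\times3323403=19940418\equiv1\pmod{19940417}\))。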
中间结果一定要及时取模,否则乘积会溢出 long long!
Code
#include <bits/stdc++.h>
using namespace std;
#define ll long long
// Modulus required by the problem. 19940417 = 7 * 2848631 is composite, so
// Fermat's little theorem cannot be used to find inverses. However,
// gcd(6, mod) = 1, so 6 still has a modular inverse.
constexpr long long mod = 19940417;
// Inverse of 6 modulo `mod`: 6 * 3323403 = 19940418 = mod + 1.
constexpr long long inv6 = 3323403;
// Guard the hard-coded inverse: a typo here would silently corrupt every
// sum-of-squares term.
static_assert(6 * inv6 % mod == 1, "inv6 must be the inverse of 6 modulo mod");

// Input values (n, m) and partial results written by main().
long long n, m, mul1, mul2, res;
// Reads the next integer from stdin.
// Leading non-digit characters are skipped. A '-' counts as the sign only when
// it immediately precedes the first digit. Returns 0 if end of input is reached
// before any digit. `ch` is an int so EOF is distinguishable from a real
// character; the skip loop stops at EOF instead of spinning forever.
long long read()
{
    long long value = 0;
    bool negative = false;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        negative = (ch == '-');
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -value : value;
}
int main()
{
n = read(); m = read();
for (ll l = 1, r; l <= n; l = r + 1)
{
r = min(n, n / (n / l));
mul1 = (mul1 + 1ll * n % mod * (r - l + 1ll) % mod - (n / l) % mod * ((1ll * (r - l + 1) * (l + r) / 2ll) % mod) % mod + 10ll * mod) % mod;
}
for (ll l = 1, r; l <= m; l = r + 1)
{
r = min(m, m / (m / l));
mul2 = (mul2 + 1ll * m % mod * (r - l + 1ll) % mod - (m / l) % mod * ((1ll * (r - l + 1) * (l + r) / 2ll) % mod) % mod + 10ll * mod) % mod;
}
res = mul1 * mul2 % mod;
ll d = min(n, m);
for (ll l = 1, r; l <= d; l = r + 1)
{
r = min(d, min(n / (n / l), m / (m / l)));
res = (res - 1ll * (r - l + 1) % mod * n % mod * m % mod + 1ll * (n / l) % mod * m % mod * ((1ll * (r - l + 1) * (l + r) / 2ll) % mod) % mod + 1ll * (m / l) % mod * n % mod * ((1ll * (r - l + 1) * (l + r) / 2ll) % mod) % mod - 1ll * (n / l) % mod * (m / l) % mod * (((1ll * r % mod * (r + 1) % mod * (2ll * r + 1ll) % mod * inv6 % mod - 1ll * l % mod * (l - 1) % mod * (2ll * l - 1ll) % mod * inv6 % mod + 10ll * mod) % mod) + 10ll * mod) % mod + 10ll * mod) % mod;
}
printf("%lld\n", (res + 10ll * mod) % mod);
return 0;
}