BZOJ2956: 模积和
Description
求∑∑((n mod i)*(m mod j))其中1<=i<=n,1<=j<=m,i≠j。
Input
第一行两个数n,m。
Output
一个整数表示答案mod 19940417的值
Sample Input
3 4
Sample Output
1
样例说明
答案为(3 mod 1)*(4 mod 2)+(3 mod 1) * (4 mod 3)+(3 mod 1) * (4 mod 4) + (3 mod 2) * (4 mod 1) + (3 mod 2) * (4 mod 3) + (3 mod 2) * (4 mod 4) + (3 mod 3) * (4 mod 1) + (3 mod 3) * (4 mod 2) + (3 mod 3) * (4 mod 4) = 1
数据规模和约定
对于100%的数据n,m<=10^9。
Solution
题目就是求
\[∑_{i=1}^n∑_{j=1}^m[i≠j](n\space mod\space i)(m\space mod\space j)
\]
先讨论不考虑i≠j的限制条件的情况
\[\large
\begin{align*}
&\sum_{i=1}^n\sum_{j=1}^m(n\space mod\space i)(m\space mod\space j)\\
&=\sum\sum{(n-\frac{n}{i}*i)(m-\frac{m}{j}*j)}\\
&=\sum_{i=1}^{n}\sum_{j=1}^{m}{nm-\frac{n}{i}*i*m-n*\frac{m}{j}*j+i*j*\frac{n}{i}*\frac{m}{j}}\\
&=n^2m^2-nm^2\sum_{i=1}^{n}{\frac{n}{i}*i}-n^2*m\sum_{j=1}^m{\frac{m}{j}*j}+nm\sum_{i=1}^{n}{i*\frac{n}{i}*}\sum_{j=1}^{m}{j*\frac{m}{j}}
\end{align*}
\]
这是一种方法
然而还有更简便的方法
\[\large
\sum{n\space mod\space i}*\sum{m\space mod\space j}
\]
直接用余数之和那题的方法求这个就好(不知道余数之和那题怎么写的戳这里)
就不用上面一大堆码起来也麻烦的式子了
对于i==j的情况
\[\large
\begin{align*}
&\sum_{i=1}^{k=min(n,m)}{(n-\frac{n}{i}*i)(m-\frac{m}{i}*i)}[i==j]\\
&=\sum_{i=1}^{k}{nm-m*\frac{n}{i}*i-n*\frac{m}{i}*i+i^2*\frac{n}{i}*\frac{m}{i}}\\
&=knm-km\sum_{i=1}^{k}{\frac{n}{i}*i}-kn\sum_{i=1}^{k}{\frac{m}{i}*i}+k\sum_{i=1}^{k}{i^2}\sum_{i=1}^{k}{\frac{n}{i}}\sum_{i=1}^{k}{\frac{m}{i}}
\end{align*}
\]
利用数论分块\(O(\sqrt{n})\)求出上面两式,将两式相减即可
P.S:\(\sum_{i=1}^n{i^2}=\frac{n*(n+1)*(2n+1)}{6}\)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define N 2010
#define mod 19940417
const ll m6 = 3323403;
ll n, m;
ll ans = 0;
ll sum(ll l, ll r) {
return (r - l + 1) * (l + r) / 2 % mod;
}
ll calc(ll k) {
ll ans = k * k % mod;
for(int l = 1, r; l <= k; l = r + 1) {
r = k / (k / l);
ans = ((ans - sum(l, r) * (k / l) % mod) % mod + mod) % mod;
}
return ans;
}
ll cal(ll x) {
return x * (x + 1) % mod * (2 * x + 1) % mod * m6 % mod;
}
ll sum2(ll l, ll r) {
return (cal(r) - cal(l - 1) + mod) % mod;
}
int main() {
scanf("%lld%lld", &n, &m);
if(n > m) swap(n, m);
ans = calc(n) * calc(m) % mod;
ans = ((ans - n * n % mod * m % mod) % mod + mod) % mod;
for(int l = 1, r; l <= n; l = r + 1) {
r = min(n / (n / l), m / (m / l));
ans = (ans + sum(l, r) * ((n/l)*m % mod + (m/l)*n % mod) % mod % mod);
ans = (ans - sum2(l, r) * (n/l) % mod * (m/l) % mod + mod) % mod;
}
printf("%lld\n", (ans % mod + mod) % mod);
return 0;
}