浏览器标题切换
浏览器标题切换end
把博客园图标替换成自己的图标
把博客园图标替换成自己的图标end

BZOJ2956: 模积和

Description

 求∑∑((n mod i)*(m mod j))其中1<=i<=n,1<=j<=m,i≠j。

  

Input

第一行两个数n,m。

Output

  一个整数表示答案mod 19940417的值

Sample Input

3 4

Sample Output

1

样例说明
  答案为(3 mod 1)*(4 mod 2)+(3 mod 1) * (4 mod 3)+(3 mod 1) * (4 mod 4) + (3 mod 2) * (4 mod 1) + (3 mod 2) * (4 mod 3) + (3 mod 2) * (4 mod 4) + (3 mod 3) * (4 mod 1) + (3 mod 3) * (4 mod 2) + (3 mod 3) * (4 mod 4) = 1

数据规模和约定
  对于100%的数据n,m<=10^9。

Solution

题目就是求

\[∑_{i=1}^n∑_{j=1}^m[i≠j](n\space mod\space i)(m\space mod\space j) \]

先讨论不考虑i≠j的限制条件的情况

\[\large \begin{align*} &\sum_{i=1}^n\sum_{j=1}^m(n\space mod\space i)(m\space mod\space j)\\ &=\sum\sum{(n-\frac{n}{i}*i)(m-\frac{m}{j}*j)}\\ &=\sum_{i=1}^{n}\sum_{j=1}^{m}{nm-\frac{n}{i}*i*m-n*\frac{m}{j}*j+i*j*\frac{n}{i}*\frac{m}{j}}\\ &=n^2m^2-nm^2\sum_{i=1}^{n}{\frac{n}{i}*i}-n^2*m\sum_{j=1}^m{\frac{m}{j}*j}+nm\sum_{i=1}^{n}{i*\frac{n}{i}*}\sum_{j=1}^{m}{j*\frac{m}{j}} \end{align*} \]

这是一种方法

然而还有更简便的方法

\[\large \sum{n\space mod\space i}*\sum{m\space mod\space j} \]

直接用余数之和那题的方法求这个就好(不知道余数之和那题怎么写的戳这里

就不用上面一大堆码起来也麻烦的式子了

对于i==j的情况

\[\large \begin{align*} &\sum_{i=1}^{k=min(n,m)}{(n-\frac{n}{i}*i)(m-\frac{m}{i}*i)}[i==j]\\ &=\sum_{i=1}^{k}{nm-m*\frac{n}{i}*i-n*\frac{m}{i}*i+i^2*\frac{n}{i}*\frac{m}{i}}\\ &=knm-km\sum_{i=1}^{k}{\frac{n}{i}*i}-kn\sum_{i=1}^{k}{\frac{m}{i}*i}+k\sum_{i=1}^{k}{i^2}\sum_{i=1}^{k}{\frac{n}{i}}\sum_{i=1}^{k}{\frac{m}{i}} \end{align*} \]

利用数论分块\(O(\sqrt{n})\)求出上面两式,将两式相减即可

P.S:\(\sum_{i=1}^n{i^2}=\frac{n*(n+1)*(2n+1)}{6}\)

#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define N 2010
#define mod 19940417
const ll m6 = 3323403;
ll n, m;
ll ans = 0;

ll sum(ll l, ll r) {
	return (r - l + 1) * (l + r) / 2 % mod;
}

ll calc(ll k) {
	ll ans = k * k % mod;
	for(int l = 1, r; l <= k; l = r + 1) {
		r = k / (k / l);
		ans = ((ans - sum(l, r) * (k / l) % mod) % mod + mod) % mod; 
	} 
	return ans;
}

ll cal(ll x) {
	return x * (x + 1) % mod * (2 * x + 1) % mod * m6 % mod;
}

ll sum2(ll l, ll r) {
	return (cal(r) - cal(l - 1) + mod) % mod;
} 

int main() {
	scanf("%lld%lld", &n, &m);
	if(n > m) swap(n, m);
	ans = calc(n) * calc(m) % mod;
	ans = ((ans - n * n % mod * m % mod) % mod + mod) % mod; 
	for(int l = 1, r; l <= n; l = r + 1) {
		r = min(n / (n / l), m / (m / l));
		ans = (ans + sum(l, r) * ((n/l)*m % mod + (m/l)*n % mod) % mod % mod);
		ans = (ans - sum2(l, r) * (n/l) % mod * (m/l) % mod + mod) % mod;
	}
	printf("%lld\n", (ans % mod + mod) % mod);
	return 0;
}
posted @ 2018-12-30 23:23  henry_y  阅读(858)  评论(0编辑  收藏  举报