ACM学习历程—BZOJ2956 模积和(数论)

Description

 求∑∑((n mod i)*(m mod j))其中1<=i<=n,1<=j<=m,i≠j。

  

Input

第一行两个数n,m。

Output

  一个整数表示答案mod 19940417的值

Sample Input


3 4

Sample Output

1

样例说明
  答案为(3 mod 1)*(4 mod 2)+(3 mod 1) * (4 mod 3)+(3 mod 1) * (4 mod 4) + (3 mod 2) * (4 mod 1) + (3 mod 2) * (4 mod 3) + (3 mod 2) * (4 mod 4) + (3 mod 3) * (4 mod 1) + (3 mod 3) * (4 mod 2) + (3 mod 3) * (4 mod 4) = 1

数据规模和约定
  对于100%的数据n,m<=10^9。

 

由题意:

∑∑((n mod i) * (m mod j)) ( i≠j)= ∑(n mod i) * ∑(m mod i) - ∑((n mod i) * (m mod i))= ∑(n-[n/i]*i) * ∑(m-[m/i]*i) - ∑(nm-([n/i]+[m/i])i+[n/i][m/i]*i*i)= ∑(n-[n/i]*i) * ∑(m-[m/i]*i) – n*n*m+∑[n/i]i+∑[m/i]i-∑[n/i][m/i]*i*i(n <= m)

然后利用[n/i]的分组加速运算即可,不过中间过程有多处需要注意的,

m/(m/i)的时候需要和n比较大小,因为可能会超出范围。

此外就是int乘法可能会爆,需要转long long,中间过程别忘了MOD

 

代码:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <set>
#include <map>
#include <queue>
#include <string>
#define LL long long
#define MOD 19940417
#define nsix 3323403

using namespace std;

int n, m;

LL cal(int len, int x)
{
    LL ans = 0, tmp;
    int j;
    for (int i = 1; i <= len; ++i)
    {
        j = min(len, x/(x/i));//这一句不用min,j会越界
        tmp = ((LL)i+j)*(j-i+1)/2%MOD;
        ans += tmp*(x/i)%MOD;
        ans %= MOD;
        i = j;
    }
    return ans;
}

inline LL sum(LL x)
{
    return x*(x+1)%MOD*(2*x+1)%MOD*nsix%MOD;
}

LL cal2(int x, int y)
{
    LL ans = 0, tmp, ttmp;
    int j;
    for (int i = 1; i <= x; ++i)
    {
        j = min(x/(x/i), y/(y/i));
        //j = min(j, x);
        tmp = sum(j)-sum(i-1);
        tmp = (tmp%MOD+MOD)%MOD;
        ttmp = ((LL)x/i)*(y/i)%MOD;
        ans += tmp*ttmp%MOD;
        ans %= MOD;
        i = j;
    }
    return ans;
}

void work()
{
    if (n > m) swap(n, m);
    LL ans, m2, n2, snn, smm, snm, ss;
    m2 = (LL)m*m%MOD;
    n2 = (LL)n*n%MOD;
    smm = cal(m, m);
    snn = cal(n, n);
    snm = cal(n, m);
    ss = cal2(n, m);
    ans = m2*n2%MOD - m2*snn%MOD - n2*smm%MOD + snn*smm%MOD;
    ans -= m*n2%MOD;
    ans += m*snn%MOD;
    ans += n*snm%MOD;
    ans -= ss;
    ans = (ans%MOD+MOD)%MOD;
    printf("%lld\n", ans);
}

int main()
{
    //freopen("test.in", "r", stdin);
    while (scanf("%d%d", &n, &m) != EOF)
        work();
    return 0;
}

 

posted on 2015-10-25 16:19  AndyQsmart  阅读(197)  评论(0编辑  收藏  举报

导航