SWJTU2017-6月月赛 C-H1Z1[数论][乘法逆元]

传送门:http://www.swjtuoj.cn/problem/2393/

题意:计算nm的每个点到n*m每个位置的曼哈顿距离和

题解:考虑先计算每个点到x方向的距离和。设当前点为(X,Y),因为每一行在x方向距离和相同,所以只需算一行的距离然后乘n行。一行x方向距离和公式为两个等差数列的和$\frac{{(m - x)(m - x + 1) + (x - 1)x}}{2}$

由于要计算每个点的距离和,和之前的方法一样,计算一行的然后乘n行。${\text{dis = }}\frac{{{n^2}}}{2}\sum\limits_{x = 1}^m {(m - x) + (m - x + 1) + (x - 1)x} $

相加的两项求和最后都为1*0+2*1+...+m*(m-1) 求n平方和公式为$\frac{{n(n + 1)(2n + 1)}}{6}$ 再减一个等差即可 化简后为$\frac{{m({m^2} - 1)}}{3}$

最后x轴距离和为 $\frac{{{n^2}m({m^2} - 1)}}{3}$ y轴只需要把mn对换

答案为$\frac{{{n^2}m({m^2} - 1)}}{3} + \frac{{{m^2}n({n^2} - 1)}}{3}$

除以3划为乘3关于mod的逆元pow(3,mod-2)

代码如下:

 1 #define _CRT_SECURE_NO_DEPRECATE
 2 #pragma comment(linker, "/STACK:102400000,102400000")
 3 #include<iostream>  
 4 #include<cstdio>  
 5 #include<fstream>  
 6 #include<iomanip>
 7 #include<algorithm>  
 8 #include<cmath>  
 9 #include<deque>  
10 #include<vector>  
11 #include<assert.h>
12 #include<bitset>
13 #include<queue>  
14 #include<string>  
15 #include<cstring>  
16 #include<map>  
17 #include<stack>  
18 #include<set>
19 #include<functional>
20 #define pii pair<int, int>
21 #define mod 1000000007
22 #define mp make_pair
23 #define pi acos(-1)
24 #define eps 0.00000001
25 #define mst(a,i) memset(a,i,sizeof(a))
26 #define all(n) n.begin(),n.end()
27 #define lson(x) ((x<<1))  
28 #define rson(x) ((x<<1)|1) 
29 #define inf 0x3f3f3f3f
30 typedef long long ll;
31 typedef unsigned long long ull;
32 using namespace std;
33 ll poww(ll m, int n)
34 {
35     ll ans = 1;
36     ll temp = m%mod;
37     while (n)
38     {
39         if (n & 1)
40             ans *= temp;
41         ans %= mod;
42         temp *= temp;
43         temp %= mod;
44         n >>= 1;
45     }
46     return ans%mod;
47 }
48 int main()
49 {
50     ios::sync_with_stdio(false);
51     cin.tie(0); cout.tie(0);
52     int i, j, k;
53     ll n, m, T;
54     cin >> T;
55     while(T--)
56     {
57         cin >> n >> m;
58         ll ta = (n*n) % mod, tb = (m*m) % mod;
59         ll tc = (ta*m) % mod, td = (tb*n) % mod;
60         ll ans = (tc *(tb - 1)) % mod + (td*(ta - 1)) % mod;
61         ans = (ans*poww(3, mod - 2)) % mod;
62         cout << ans << endl;
63     }
64     return 0;
65 }

 

posted @ 2017-06-11 03:06  Meternal  阅读(294)  评论(0编辑  收藏  举报