[BZOJ 2956] 模积和
[题目链接]
https://www.lydsy.com/JudgeOnline/problem.php?id=2956
[算法]
首先有两个重要的等式 :
1. 1 + 2 + 3 + 4 + ... + n = n(n + 1) / 2
2. 1 ^ 2 + 2 ^ 2 + ... + n ^ 2 = n(n + 1)(2n + 1) / 6
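记 $n \bmod i = n - i\lfloor n/i \rfloor$。答案为 $\left(\sum_{i=1}^{n} n \bmod i\right)\left(\sum_{j=1}^{m} m \bmod j\right) - \sum_{i=1}^{\min(n,m)} (n \bmod i)(m \bmod i)$，即先算不加限制的双重和，再减去 $i = j$ 的项。把每一项展开后，根据上面两个求和公式，配合数论分块（$\lfloor n/i \rfloor$ 只有 $O(\sqrt{n})$ 种取值），即可在 $O(\sqrt{n} + \sqrt{m})$ 的时间内求解此问题。注意模数 $19940417 = 7 \times 2848631$ 不是质数，但它与 6 互质，所以除以 2、6 时可以分别乘以逆元 $(P+1)/2$、$(P+1)/6$。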
[代码]
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;

// Modulus required by the problem. It is NOT prime (19940417 = 7 * 2848631),
// so Fermat inverses do not apply; the inverses below come from closed forms.
const int P = 19940417;
// Presumably Euler's totient of P (= 6 * 2848630 if 2848631 is prime); not used by main.
const int PHI = 17091780;

// P + 1 is divisible by 6, so 2 * INV2 == P + 1 == 1 (mod P) and likewise 6 * INV6 == 1 (mod P).
static_assert((P + 1) % 6 == 0, "closed-form inverses of 2 and 6 require P % 6 == 5");
const int INV2 = (P + 1) / 2;
const int INV6 = (P + 1) / 6;   // == 3323403

template <typename T> inline void chkmax(T &x, T y) { x = max(x, y); }
template <typename T> inline void chkmin(T &x, T y) { x = min(x, y); }

// Reads one (optionally negative) integer from stdin into x.
// If EOF is reached before any digit, x is left as 0 instead of looping forever.
template <typename T> inline void read(T &x)
{
    T f = 1;
    x = 0;
    int c = getchar();   // int, not char: isdigit() needs an unsigned-char value or EOF
    for (; !isdigit(c); c = getchar())
    {
        if (c == EOF) return;
        if (c == '-') f = -f;
    }
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
    x *= f;
}

// x = (x - y) mod P, for x, y in [0, P).
inline void sub(int &x, int y) { x -= y; while (x < 0) x += P; }
// x = (x + y) mod P, for x, y in [0, P).
inline void add(int &x, int y) { x += y; while (x >= P) x -= P; }

// Returns a^e mod P by binary exponentiation (e >= 0).
// Previously tested the base's low bit instead of the exponent's, giving wrong results.
inline int exp_mod(int a, int e)
{
    int b = a, res = 1;
    while (e > 0)
    {
        if (e & 1) res = 1LL * res * b % P;
        b = 1LL * b * b % P;
        e >>= 1;
    }
    return res;
}

// Sum of k^2 for k in [l, r] (1 <= l <= r), modulo P,
// via sum_{k=1}^{x} k^2 = x(x + 1)(2x + 1) / 6.
inline int calc(int l, int r)
{
    int S1 = 1LL * r * (r + 1LL) % P * ((2LL * r + 1) % P) % P * INV6 % P;
    int S2 = 1LL * (l - 1) * l % P * ((2LL * l - 1) % P) % P * INV6 % P;
    return (S1 - S2 + P) % P;
}

// Sum of k for k in [l, r] (l <= r), modulo P, via (l + r)(r - l + 1) / 2.
inline int range_sum(int l, int r)
{
    return (0LL + l + r) % P * ((r - l + 1LL) % P) % P * INV2 % P;
}

// Returns sum_{i=1}^{x} (x mod i) modulo P.
// Uses x mod i = x - i * floor(x / i); indices sharing the same quotient
// floor(x / i) form a block [i, j], so the loop runs O(sqrt x) times.
inline int sum_mod(int x)
{
    int total = 1LL * x * x % P;
    for (int i = 1, j; i <= x; i = j + 1)
    {
        j = x / (x / i);
        sub(total, 1LL * (x / i) % P * range_sum(i, j) % P);
    }
    return total;
}

// Returns sum_{i=1}^{min(n, m)} (n mod i) * (m mod i) modulo P (the i == j terms).
// Expanding (n - i*qn)(m - i*qm) with qn = floor(n/i), qm = floor(m/i) gives
//   n*m - m*i*qn - n*i*qm + i^2*qn*qm,
// and within each block [i, j] both quotients are constant.
inline int cross_sum(int n, int m)
{
    if (n > m) swap(n, m);
    int con = 1LL * n * n % P * m % P;   // n terms of n*m
    for (int i = 1, j; i <= n; i = j + 1)
    {
        j = min(n / (n / i), m / (m / i));
        add(con, 1LL * (n / i) * (m / i) % P * calc(i, j) % P);
        sub(con, 1LL * (n / i) * m % P * range_sum(i, j) % P);
        sub(con, 1LL * (m / i) * n % P * range_sum(i, j) % P);
    }
    return con;
}

// Reads n and m and prints sum_{i<=n} sum_{j<=m, j != i} (n mod i)(m mod j) mod P.
int main()
{
    int n, m;
    read(n);
    read(m);
    // The unrestricted double sum factorises; then remove the excluded i == j terms.
    int res = 1LL * sum_mod(n) * sum_mod(m) % P;
    sub(res, cross_sum(n, m));
    printf("%d\n", res);
    return 0;
}