P2260 [清华集训2012]模积和
题目背景
数学题,无背景。
题目描述
求
$\sum\limits_{i=1}^{n} \sum\limits_{j=1}^{m} (n \bmod i) \times (m \bmod j), i \neq j \;\bmod\;19940417$ 的值
输入输出格式
输入格式:
两个整数n m
输出格式:
答案 mod 19940417
输入输出样例
说明
30%: n,m <= 1000
60%: n,m <= 10^6
100% n,m <= 10^9
Solution:
本题实在是太贼有意思了。。。
开始没有发现条件$i\neq j$,结果一直以为取模错误,搞了半天。
首先,我们先忽略$i\neq j$的条件,直接求$\sum\limits_{i=1}^{n} \sum\limits_{j=1}^{m} (n \bmod i) \times (m \bmod j)$。
对原式化简:原式$=\sum\limits_{i=1}^{n}{(n-i\times\lfloor{n/i}\rfloor)\sum\limits_{j=1}^{m}{(m-j\times\lfloor{m/j}\rfloor)}}=(n^2-\sum\limits_{i=1}^{n}{(i\times\lfloor{n/i}\rfloor))(m^2-\sum\limits_{j=1}^{m}{(j\times\lfloor{m/j}\rfloor)})}$,然后对这个式子两边各自一遍数列分块套上等差数列求和,求出$ans$值并取模。
由于多算了$i==j$的情况,所以我们还要从$ans$中减去$i==j$的情况。
对于$i==j$的情况累加的值$tot$,容易得出$tot=\sum\limits_{i=1}^{min(n,m)}{(n-i\times\lfloor{n/i}\rfloor)\times(m-j\times\lfloor{m/i}\rfloor)}=\sum\limits_{i=1}^{min(n,m)}{n\times m+i^2\times\lfloor{n/i\times m/i}\rfloor-n\times i\times \lfloor{m/i}\rfloor-m\times i\times\lfloor{n/i}\rfloor}$
然后对于$n\times m$直接累加,对于$n\times i\times \lfloor{m/i}\rfloor\;,\;m\times i\times\lfloor{n/i}\rfloor$还是数论分块套上等差数列求和。
难的是求$i^2\times\lfloor{n/i\times m/i}\rfloor$,此时应用一个数学公式:$1^2+2^2+…+x^2=\frac{x\times(x+1)\times(2\times x+1)}{6}$,设$c=min(n/(n/i),m/(m/i))$,那么$i^2+{(i+1)}^2+…+c^2=\frac{c\times(c+1)\times(2\times c+1)}{6}-\frac{(i-1)\times((i-1)+1)\times(2\times (i-1)+1)}{6}$,这样原式就能直接公式求了,但是因为存在取模的条件,所以此时骚操作是处理出$6$的因子并约掉,或者直接预先算出$6$关于模数的逆元$inv$。
最后输出$ans$就好了。
代码:
#include<bits/stdc++.h> #define il inline #define ll long long #define For(i,a,b) for(int (i)=(a);(i)<=(b);(i)++) #define Bor(i,a,b) for(int (i)=(b);(i)>=(a);(i)--) #define Max(a,b) ((a)>(b)?(a):(b)) #define Min(a,b) ((a)>(b)?(b):(a)) using namespace std; const int mod = 19940417 , inv = 3323403; ll n,m; il ll solve(ll x){ ll ans=(x%mod*x%mod)%mod,p,c; for(ll i=1;i<=x;i=p+1){ p=x/(x/i); ans=(ans-(p+i)*(p-i+1)/2%mod*(x/i)%mod+mod)%mod; } return ans; } il ll get(ll x){return x*(x+1)%mod*(x<<1|1)%mod*inv%mod;} int main(){ ios::sync_with_stdio(0); cin>>n>>m; ll p,sum1,sum2,sum3,ans=solve(n)*solve(m)%mod; if(n>m)swap(n,m); for(ll i=1;i<=n;i=p+1){ p=Min(n/(n/i),m/(m/i)); sum1=(m*n%mod*(p-i+1))%mod; sum2=(n/i)*(m/i)%mod*(get(p)-get(i-1)+mod)%mod; sum3=(p-i+1)*(p+i)/2%mod*(n/i*m%mod+m/i*n%mod); ans=(ans-(sum1+sum2-sum3)%mod+mod)%mod; } cout<<ans%mod; return 0; }