[BZOJ3518] 点组计数
[BZOJ3518] 点组计数
Description
平面上摆放着一个nm的点阵(下图所示是一个34的点阵)。Curimit想知道有多少三点组(a,b,c)满足以a,b,c三点共线。这里a,b,c是不同的3个点,其顺序无关紧要。(即(a,b,c)和(b,c,a)被认为是相同的)。由于答案很大,故你只需要输出答案对1,000,000,007的余数就可以了。
Input
有且仅有一行,两个用空格隔开的整数n和m。
Output
有且仅有一行,一个整数,表示三点组的数目对1,000,000,007的余数。(1,000。000。007是质数)
Sample Input
3 4
Sample Output
2 0
HINT
对于100%的数据,1< =N.m< =50000
试题分析
由于两点确定一条直线,所以常规思想就是枚举一条线段,然后计算中间有多少点。
因为在坐标系中竖线和横线特殊,所以我们单拿出来考虑,就是:$$\binom {n}{3} \times m + \binom {m}{3} \times n$$
那么怎么求一条线段中间有多少点呢?假设枚举到的两个点为\(p_1\),\(p_3\),它们的横、纵坐标差为\(i\),\(j\)。
然后还要设中间的点坐标为\(p_2\),它们的横、纵坐标差为\(x\),\(y\)。
显然斜率唯一,所以斜率要相等,也就是:$$\frac{j}{i}=\frac{y}{x}$$
于是\(\frac{j}{gcd(i,j)} | y\),\(y\leq j\),所以\(cnt_y=\frac{j}{\frac{j}{gcd(i,j)}}\)。
由于每一个\(y\)对应一个\(x\),所以不需要讨论\(x\)了。
那么能取的合法的点就是\(gcd(i,j)-1\)个(还要减去\(y=i\)的情况)
那么就是求:$$ans=2\times \sum_{i=1}{n-1}\sum_{j=1} (n-i)(m-j)(gcd(i,j)-1)$$
根据欧拉反演有:$$ans=2\times \sum_{i=1}{n-1}\sum_{j=1} (n-i)(m-j)(\sum_{d|i,d|j} \phi(d) -1)$$
然后变成枚举\(d\):$$ans=2\times(\sum_\limits{d=1}^{min(n-1,m-1)} \phi(d) \sum_\limits{i=1}^{\lfloor \frac{n-1}{d} \rfloor} \sum_\limits{j=1}^{\lfloor \frac{m-1}{d} \rfloor} (n-id)(n-jd) - \sum\limits_{i=1}n\sum\limits_{j=1}m (n-i)(m-j))$$
枚举计算即可。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
#define LL long long
inline LL read(){
LL x=0,f=1; char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
return x*f;
}
const LL MAXN = 200010;
const LL INF = 2147483600;
const LL Mod = 1000000007;
const LL inv2 = 500000004;
const LL inv6 = 166666668;
LL N,M; LL pri[MAXN+1],phi[MAXN+1],cnt; bool vis[MAXN+1];
inline void init(){
vis[1]=1; phi[1]=1;
for(LL i=2;i<MAXN;i++){
if(!vis[i]) pri[++cnt]=i,phi[i]=i-1;
for(LL j=1;j<=cnt&&pri[j]*i<MAXN;j++){
vis[i*pri[j]]=true;
if(i%pri[j]==0) {phi[pri[j]*i]=pri[j]*phi[i]; break;}
phi[pri[j]*i]=(pri[j]-1)*phi[i];
}
} return ;
}
inline LL cal(LL n,LL k,LL d){
return 1LL*(k*n%Mod-1LL*d*(k+1)%Mod*k%Mod*inv2%Mod+Mod)%Mod;
}
inline LL C(LL b){
return b*(b-1)%Mod*(b-2)%Mod*inv6%Mod;
}
int main(){
//freopen(".in","r",stdin);
//freopen(".out","w",stdout);
N=read(),M=read(); init(); LL ans=0;
if(N>M) swap(N,M);
for(LL i=1;i<=N;i++){
ans=ans+phi[i]*cal(M,M/i,i)%Mod*cal(N,N/i,i)%Mod;
ans%=Mod;
} ans=(ans-((N-1)*N%Mod*inv2%Mod*(M-1)*M%Mod*inv2%Mod)+Mod)%Mod;
printf("%lld\n",(2*ans+C(M)%Mod*N%Mod+C(N)%Mod*M%Mod)%Mod);
return 0;
}