【BZOJ3518】点组计数
Description
平面上摆放着一个\(n*m\)的点阵(下图所示是一个3*4的点阵)。Curimit想知道有多少三点组(a,b,c)满足以a,b,c三点共线。这里a,b,c是不同的3个点,其顺序无关紧要。(即(a,b,c)和
(b,c,a)被认为是相同的)。由于答案很大,故你只需要输出答案对1,000,000,007的余数就可以了。
\(1\le n,m\le50000\).
Solution
首先是横着的和竖着的点组,总共有\(m{n \choose 3}+n{m\choose 3}\)种方案。
接下来是斜着的点组,怎么统计?
我想复杂了,想枚举A点和B点,然后作一条射线,统计射线上有多少个点......
更加简洁的做法是:枚举A点和C点,然后看两点的连线经过多少个点。现在只考虑A在C的左上角的情况,A在C右上角的情况同理,只需要对答案乘2。
我们枚举\(C\)相对\(A\)偏移值\(x\)和\(y\),若\(A(i,j)\),则\(C(i+x,j+y)\)。这样的\(AC\)组合会出现\((n-x)(m-y)\)次。我们发现,\(AC\)连线经过的点数恰好是\(\gcd(x,y)-1\),因此左上往右下斜的答案就是:
\[\begin{aligned}
S&=\sum_{x=1}^{n-1}\sum_{y=1}^{m-1}(n-x)(m-y)(\gcd(x,y)-1)\\
S+\sum_{x=1}^{n-1}\sum_{y=1}^{m-1}(n-x)(m-y)&=\sum_{x=1}^{n-1}\sum_{y=1}^{m-1}(n-x)(m-y)\gcd(x,y)\\
\end{aligned}
\]
左边的和式可以直接算,我们关心右边:这里用到了欧拉函数的性质\(\sum_{d|n}\varphi(d)=n\):
\[\begin{aligned}
右边&=\sum_{x=1}^{n-1}\sum_{y=1}^{m-1}(n-x)(m-y)\sum_{d|\gcd(x,y)}\varphi(d)\\
&=\sum_{d=1}^{\min(n-1,m-1)}\varphi(d)\sum_{i=1}^{\lfloor\frac{n-1}d\rfloor}(n-i*d)\sum_{j=1}^{\lfloor\frac{m-1}d\rfloor}(m-j*d)
\end{aligned}
\]
枚举\(d\),用等差数列求和\(O(1)\)算出后面的部分,于是这一段部分也就计算完了!那么总答案就是:
\[Ans=m{n \choose 3}+n{m\choose 3}+S
\]
Code
#include <cstdio>
using namespace std;
const int N=50005,MOD=1e9+7,INV2=500000004;
int n,m;
bool vis[N];
int p[N],pcnt,phi[N];
int fact[N],invfact[N];
inline int min(int x,int y){return x<y?x:y;}
inline int max(int x,int y){return x>y?x:y;}
inline int ksm(int x,int y){
int res=1;
for(;y;x=1LL*x*x%MOD,y>>=1)
if(y&1) res=1LL*res*x%MOD;
return res;
}
void sieve(){
phi[1]=1;
for(int i=2;i<N;i++){
if(!vis[i]){
p[++pcnt]=i;
phi[i]=i-1;
}
for(int j=1;j<=pcnt&&i*p[j]<N;j++){
int x=i*p[j];
vis[x]=true;
if(i%p[j]==0){
phi[x]=phi[i]*p[j];
break;
}
phi[x]=phi[i]*(p[j]-1);
}
}
}
inline int C(int n,int m){
if(m<0||m>n) return 0;
return 1LL*fact[n]*invfact[m]%MOD*invfact[n-m]%MOD;
}
inline int calc(int a,int d){
return (1LL*((a-1)/d)*a%MOD-1LL*INV2*d%MOD*((a-1)/d+1)%MOD*((a-1)/d)%MOD)%MOD;
}
int main(){
sieve();
scanf("%d%d",&n,&m);
int maxs=max(n,m),mins=min(n-1,m-1);
fact[0]=1;
for(int i=1;i<=maxs;i++) fact[i]=1LL*fact[i-1]*i%MOD;
invfact[maxs]=ksm(fact[maxs],MOD-2);
for(int i=maxs-1;i>=0;i--) invfact[i]=1LL*invfact[i+1]*(i+1)%MOD;
int ans=0;
for(int d=1;d<=mins;d++)
(ans+=1LL*phi[d]*calc(n,d)%MOD*calc(m,d)%MOD)%=MOD;
for(int i=1;i<n;i++)
(ans-=1LL*(n-i)*(1LL*m*(m-1)%MOD*INV2%MOD)%MOD)%=MOD;
ans=2LL*ans%MOD;
(ans+=(1LL*n*C(m,3)%MOD+1LL*m*C(n,3)%MOD)%MOD)%MOD;
printf("%d\n",ans<0?ans+MOD:ans);
return 0;
}