【BZOJ3518】点组计数

Description

  
  平面上摆放着一个\(n*m\)的点阵(下图所示是一个3*4的点阵)。Curimit想知道有多少三点组(a,b,c)满足以a,b,c三点共线。这里a,b,c是不同的3个点,其顺序无关紧要。(即(a,b,c)和
(b,c,a)被认为是相同的)。由于答案很大,故你只需要输出答案对1,000,000,007的余数就可以了。
  
img
  
  \(1\le n,m\le50000\).
  
  
  
     

Solution

  
  首先是横着的和竖着的点组,总共有\(m{n \choose 3}+n{m\choose 3}\)种方案。
   
  接下来是斜着的点组,怎么统计?
  
  我想复杂了,想枚举A点和B点,然后作一条射线,统计射线上有多少个点......
  
  更加简洁的做法是:枚举A点和C点,然后看两点的连线经过多少个点。现在只考虑A在C的左上角的情况,A在C右上角的情况同理,只需要对答案乘2。
  
  我们枚举\(C\)相对\(A\)偏移值\(x\)\(y\),若\(A(i,j)\),则\(C(i+x,j+y)\)。这样的\(AC\)组合会出现\((n-x)(m-y)\)次。我们发现,\(AC\)连线经过的点数恰好是\(\gcd(x,y)-1\),因此左上往右下斜的答案就是:

\[\begin{aligned} S&=\sum_{x=1}^{n-1}\sum_{y=1}^{m-1}(n-x)(m-y)(\gcd(x,y)-1)\\ S+\sum_{x=1}^{n-1}\sum_{y=1}^{m-1}(n-x)(m-y)&=\sum_{x=1}^{n-1}\sum_{y=1}^{m-1}(n-x)(m-y)\gcd(x,y)\\ \end{aligned} \]

  左边的和式可以直接算,我们关心右边:这里用到了欧拉函数的性质\(\sum_{d|n}\varphi(d)=n\)

\[\begin{aligned} 右边&=\sum_{x=1}^{n-1}\sum_{y=1}^{m-1}(n-x)(m-y)\sum_{d|\gcd(x,y)}\varphi(d)\\ &=\sum_{d=1}^{\min(n-1,m-1)}\varphi(d)\sum_{i=1}^{\lfloor\frac{n-1}d\rfloor}(n-i*d)\sum_{j=1}^{\lfloor\frac{m-1}d\rfloor}(m-j*d) \end{aligned} \]

  枚举\(d\),用等差数列求和\(O(1)\)算出后面的部分,于是这一段部分也就计算完了!那么总答案就是:

\[Ans=m{n \choose 3}+n{m\choose 3}+S \]

  
  

Code

  

#include <cstdio>
using namespace std;
const int N=50005,MOD=1e9+7,INV2=500000004;
int n,m;
bool vis[N];
int p[N],pcnt,phi[N];
int fact[N],invfact[N];
inline int min(int x,int y){return x<y?x:y;}
inline int max(int x,int y){return x>y?x:y;}
inline int ksm(int x,int y){
	int res=1;
	for(;y;x=1LL*x*x%MOD,y>>=1)
		if(y&1) res=1LL*res*x%MOD;
	return res;
}
void sieve(){
	phi[1]=1;
	for(int i=2;i<N;i++){
		if(!vis[i]){
			p[++pcnt]=i;
			phi[i]=i-1;
		}
		for(int j=1;j<=pcnt&&i*p[j]<N;j++){
			int x=i*p[j];
			vis[x]=true;
			if(i%p[j]==0){
				phi[x]=phi[i]*p[j];
				break;
			}
			phi[x]=phi[i]*(p[j]-1);
		}
	}
}
inline int C(int n,int m){
	if(m<0||m>n) return 0;
	return 1LL*fact[n]*invfact[m]%MOD*invfact[n-m]%MOD;
}
inline int calc(int a,int d){
	return (1LL*((a-1)/d)*a%MOD-1LL*INV2*d%MOD*((a-1)/d+1)%MOD*((a-1)/d)%MOD)%MOD;
}
int main(){
	sieve();
	scanf("%d%d",&n,&m);
	int maxs=max(n,m),mins=min(n-1,m-1);
	fact[0]=1;
	for(int i=1;i<=maxs;i++) fact[i]=1LL*fact[i-1]*i%MOD;
	invfact[maxs]=ksm(fact[maxs],MOD-2);
	for(int i=maxs-1;i>=0;i--) invfact[i]=1LL*invfact[i+1]*(i+1)%MOD;
	int ans=0;			
	for(int d=1;d<=mins;d++)
		(ans+=1LL*phi[d]*calc(n,d)%MOD*calc(m,d)%MOD)%=MOD;
	for(int i=1;i<n;i++)	
		(ans-=1LL*(n-i)*(1LL*m*(m-1)%MOD*INV2%MOD)%MOD)%=MOD;
	ans=2LL*ans%MOD;
	(ans+=(1LL*n*C(m,3)%MOD+1LL*m*C(n,3)%MOD)%MOD)%MOD;
	printf("%d\n",ans<0?ans+MOD:ans);
	return 0;
}
posted @ 2018-06-21 17:40  RogerDTZ  阅读(119)  评论(0编辑  收藏  举报