[LOJ3119][CTS2019|CTSC2019]随机立方体:组合数学+二项式反演

分析

感觉这道题的计数方法好厉害。。

一个直观的思路是,把题目转化为求至少有\(k\)个极大的数的概率。

考虑这样一个事实,如果钦定\((1,1,1),(2,2,2),...,(k,k,k)\)是那\(k\)个极大值的位置,并且\(val(1,1,1) < val(2,2,2) < ... < val(k,k,k)\)。我们考虑依次确定这些值,显然\(val(1,1,1)\)的值是和它至少有一维相同的\(n \times m \times l - (n-1) \times (m-1) \times (l-1)\)个位置中最大的一个,\(val(2,2,2)\)的值是和它至少有一维相同的所有位置并上\((1,1,1)\)限制到的所有位置中最大的一个,即\(n \times m \times l - (n-2) \times (m-2) \times (l-2)\)个位置中最大的一个。

以此类推,\(val(i,i,i)\)的值是\(n \times m \times l - (n-i) \times (m-i) \times (l-i)\)个位置中最大的一个。我们记\(cnt(i) = n \times m \times l - (n-i) \times (m-i) \times (l-i)\),进而可以得到\((i,i,i)\)是极大值的概率是\(P(i) = \frac{1}{cnt(i)}\)

现在我们所需要的就是计算\(P(i)\)的前缀和,这个可以通过线性处理逆元的技巧完成,然后二项式反演,式子如下,其中\(A_n^m\)表示排列数:

\[ans = \sum_{i=k}^{n}(-1)^{i-k}\binom{i}{k}A_n^iA_m^iA_l^i\prod_{j=1}^{i}P(j) \]

关于概率为什么能二项式反演?

可以这样理解:概率再乘上个阶乘就是方案数了。

yyb聚聚的题解

戳这里

分析的方式不太一样,不过本质和最后得到的结果是相同的。

代码

#include <bits/stdc++.h>

#define rin(i,a,b) for(int i=(a);i<=(b);++i)
#define irin(i,a,b) for(int i=(a);i>=(b);--i)
#define trav(i,a) for(int i=head[a];i;i=e[i].nxt)
#define Size(a) (int) a.size()
#define pb push_back
#define mkpr std::make_pair
#define fi first
#define se second
#define lowbit(a) ((a)&(-(a)))
typedef long long LL;

using std::cerr;
using std::endl;

inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}

const int MAXN=5000005;
const int MOD=998244353;

int n,m,l,k;
int fac[MAXN],invf[MAXN];
int cnt[MAXN],fix[MAXN];

inline int qpow(int x,int y){
	int ret=1,tt=x%MOD;
	while(y){
		if(y&1)ret=1ll*ret*tt%MOD;
		tt=1ll*tt*tt%MOD;
		y>>=1;
	}
	return ret;
}

inline int C(int n,int m){
	if(n<0||m<0||n<m)return 0;
	return 1ll*fac[n]*invf[n-m]%MOD*invf[m]%MOD;
}

inline int A(int n,int m){
	if(n<0||m<0||n<m)return 0;
	return 1ll*fac[n]*invf[n-m]%MOD;
}

void init(int n){
	fac[0]=1;
	rin(i,1,n)fac[i]=1ll*fac[i-1]*i%MOD;
	invf[n]=qpow(fac[n],MOD-2);
	irin(i,n-1,0)invf[i]=1ll*invf[i+1]*(i+1)%MOD;
}

int main(){
	init(5000000);
	int T=read();
	while(T--){
		int inp[4];
		rin(i,1,3)inp[i]=read();
		std::sort(inp+1,inp+4);
		n=inp[1],m=inp[2],l=inp[3];
		k=read();
		int tot=1;
		rin(i,1,n){
			cnt[i]=(1ll*n*m%MOD*l%MOD-1ll*(n-i)*(m-i)%MOD*(l-i)%MOD+MOD)%MOD;
			tot=1ll*tot*cnt[i]%MOD;
		}
		fix[n]=qpow(tot,MOD-2);
		irin(i,n-1,1)fix[i]=1ll*fix[i+1]*cnt[i+1]%MOD;
		int ans=0,sgn=MOD-1;
		rin(i,k,n){
			sgn=MOD-sgn;
			ans=(ans+1ll*sgn*C(i,k)%MOD*A(n,i)%MOD*A(m,i)%MOD*A(l,i)%MOD*fix[i])%MOD;
		}
		printf("%d\n",ans);
	}
	return 0;
}

posted on 2019-05-21 22:30  ErkkiErkko  阅读(155)  评论(0编辑  收藏  举报