loj #3119. 「CTS2019 | CTSC2019」随机立方体

简单二项式反演吧

先构造一个向量\(ans,ans(i)\)表示k=i时的答案

\(ans\)右乘一个杨辉三角得到一个向量\(f\),如果可以计算出\(f\)

那么我们可以在\(O(n)\)的时间通过乘杨辉三角逆矩阵计算出\(ans(k)\)

这个技巧也被称之为二项式反演

(下划线被latex吞的有点不明显,式子中的幂均是下降幂)

现在考虑如何计算\(f(i)\)

为了方便起见我们令\(N=nml\)

将每一种方案按照以下方式拆分,然后使用乘法原理计数

1.钦定i个位置作为极大位置

这一部分的方案数是\(n^{\underline i}m^{\underline i}l^{\underline i}\)

2.钦定\(g(i)\)个值给极大值和被极大值支配的位置,

其中\(g(i)\)表示有i个极大值时被极大值支配的位置总数

\(g(i)=N-(n-i)(m-i)(l-i)\)

这一部分的方案数是\({N \choose g(i)}\)

3.给这\(g(i)\)个位置分配大小关系,使得钦定的值全部是极大值

设这一部分的方案数是\(h(i)\),如何计算\(h(i)\)一会再说

4.剩下的位置随便放置,这一部分的方案数是\((N-g(i))!\)

所以我们可以得到这个式子

\[f(i)=n^{\underline i}m^{\underline i}l^{\underline i}{N \choose g(i)}(N-g(i))!h(i) \]

现在我们考虑把每一个被极大值支配的点以这种方式建一颗树出来

1.如果它不是极大值,那么向和它在同一个面上最小的极大值连边

2.否则这个位置是极大值,向比它小的极大位置中最大的极大位置连边

这样可以搞出一颗树来,容易看出只要钦定的大小关系满足树的拓扑序,就是一个合法的方案

直接套树的拓扑序方案数可以得出\(h(i)=\frac{g(i)!}{\prod_{i=1}^{n}g(i)}\)

所以稍微化简一下就是

\[f(i)=\frac{n^{\underline i}m^{\underline i}l^{\underline i}}{\prod_{i=1}^{n}g(i)} \]

通过离线求出\(O(n)\)求出g的逆元就能\(O(n)\)的计算\(f\)

然后稍微反演一下就可以输出答案了

#include<cstdio>
#include<algorithm>
using namespace std;typedef long long ll;
const ll mod=998244353;const int N=5*1e6+10;
inline ll po(ll a,ll p){ll r=1;for(;p;p>>=1,a=a*a%mod)if(p&1)r=r*a%mod;return r;}
ll fac[N];ll ifac[N];
inline void prew()
{
	fac[0]=1;
	for(int i=1;i<N;i++)
		fac[i]=fac[i-1]*i%mod;
	ifac[0]=1;ifac[1]=1;
	for(int i=2;i<N;i++)
		ifac[i]=(mod-mod/i)*ifac[mod%i]%mod;
	for(int i=1;i<N;i++)
		(ifac[i]*=ifac[i-1])%=mod;
	//for(int i=1;i<=10;i++)printf("%lld ",fac[i]);printf("\n");
	//for(int i=1;i<=10;i++)printf("%lld ",ifac[i]);printf("\n");
}
inline ll c(ll n,ll m)
{
	//ll
	return fac[n]*ifac[m]%mod*ifac[n-m]%mod;
	//printf("ret=%lld\n",ret);return ret;
}
inline ll dpo(ll n,ll p)
{
	return fac[n]*ifac[n-p]%mod;
}
struct data
{
	ll val;int cnt;
	inline void giv(){val=po(val,mod-2);cnt=-cnt;}
	data (ll a=1,ll b=0)
	{
		val=a;cnt=b;
	}
	friend data operator *(data a,data b)
	{
		return data((a.val*b.val)%mod,a.cnt+b.cnt);
	}
	inline void ih(ll sv=1)
	{
		cnt=(sv==0);val=sv+cnt;
	}
	inline ll gval(){return val*(cnt==0);}
}pre[N],suf[N];
ll f[N];ll g[N];int k;
int n;int m;int l;
inline void solve()
{
	scanf("%d%d%d%d",&n,&m,&l,&k);
	int mi=min(min(n,m),l);
	if(k>mi){printf("0\n");return;}
	for(int i=1;i<=mi;i++)
		g[i]=((ll)n*m%mod*l%mod+(mod-n+i)*(m-i)%mod*(l-i)%mod)%mod;
	//for(int i=1;i<=mi;i++)
	//	printf("%lld ",g[i]);printf("\n");
	for(int i=1;i<=mi;i++)
		pre[i].ih(g[i]),suf[i].ih(g[i]);
	for(int i=1;i<=mi;i++)
		pre[i]=pre[i]*pre[i-1];
	for(int i=mi;i>=1;i--)
		suf[i]=suf[i]*suf[i+1];
	data iv=pre[mi];iv.giv();
	for(int i=1;i<=mi;i++)
		g[i]=(iv*pre[i-1]*suf[i+1]).gval();
	g[0]=1;
	for(int i=1;i<=mi;i++)
		(g[i]*=g[i-1])%=mod;
	//for(int i=1;i<=mi;i++)
	//	printf("%lld ",g[i]*56%mod);printf("\n");
	for(int i=1;i<=mi;i++)
		f[i]=dpo(n,i)*dpo(m,i)%mod*dpo(l,i)%mod*g[i]%mod;
//	for(int i=1;i<=mi;i++)
//		printf("%lld ",f[i]);printf("\n");
	ll res=0;
	for(int i=k,tp=0;i<=mi;i++,tp^=1)
		if(tp==0)(res+=c(i,k)*f[i])%=mod;
		else (res+=(mod-c(i,k))*f[i])%=mod;
	printf("%lld\n",res);
	for(int i=0;i<=mi+1;i++)pre[i].ih();
	for(int i=0;i<=mi+1;i++)suf[i].ih();
}
int main()
{
	prew();
	int T;scanf("%d",&T);
	for(int i=1;i<=T;i++)
		solve();
	return 0;
}
posted @ 2019-06-06 17:28  sweetphoenix  阅读(161)  评论(0编辑  收藏  举报