loj #3119. 「CTS2019 | CTSC2019」随机立方体
简单二项式反演吧
先构造一个向量\(ans,ans(i)\)表示k=i时的答案
把\(ans\)右乘一个杨辉三角得到一个向量\(f\),如果可以计算出\(f\)
那么我们可以在\(O(n)\)的时间通过乘杨辉三角逆矩阵计算出\(ans(k)\)
这个技巧也被称之为二项式反演
(下划线被latex吞的有点不明显,式子中的幂均是下降幂)
现在考虑如何计算\(f(i)\)
为了方便起见我们令\(N=nml\)
将每一种方案按照以下方式拆分,然后使用乘法原理计数
1.钦定i个位置作为极大位置
这一部分的方案数是\(n^{\underline i}m^{\underline i}l^{\underline i}\)
2.钦定\(g(i)\)个值给极大值和被极大值支配的位置,
其中\(g(i)\)表示有i个极大值时被极大值支配的位置总数
\(g(i)=N-(n-i)(m-i)(l-i)\)
这一部分的方案数是\({N \choose g(i)}\)
3.给这\(g(i)\)个位置分配大小关系,使得钦定的值全部是极大值
设这一部分的方案数是\(h(i)\),如何计算\(h(i)\)一会再说
4.剩下的位置随便放置,这一部分的方案数是\((N-g(i))!\)
所以我们可以得到这个式子
现在我们考虑把每一个被极大值支配的点以这种方式建一颗树出来
1.如果它不是极大值,那么向和它在同一个面上最小的极大值连边
2.否则这个位置是极大值,向比它小的极大位置中最大的极大位置连边
这样可以搞出一颗树来,容易看出只要钦定的大小关系满足树的拓扑序,就是一个合法的方案
直接套树的拓扑序方案数可以得出\(h(i)=\frac{g(i)!}{\prod_{i=1}^{n}g(i)}\)
所以稍微化简一下就是
通过离线求出\(O(n)\)求出g的逆元就能\(O(n)\)的计算\(f\)了
然后稍微反演一下就可以输出答案了
#include<cstdio>
#include<algorithm>
using namespace std;typedef long long ll;
const ll mod=998244353;const int N=5*1e6+10;
inline ll po(ll a,ll p){ll r=1;for(;p;p>>=1,a=a*a%mod)if(p&1)r=r*a%mod;return r;}
ll fac[N];ll ifac[N];
inline void prew()
{
fac[0]=1;
for(int i=1;i<N;i++)
fac[i]=fac[i-1]*i%mod;
ifac[0]=1;ifac[1]=1;
for(int i=2;i<N;i++)
ifac[i]=(mod-mod/i)*ifac[mod%i]%mod;
for(int i=1;i<N;i++)
(ifac[i]*=ifac[i-1])%=mod;
//for(int i=1;i<=10;i++)printf("%lld ",fac[i]);printf("\n");
//for(int i=1;i<=10;i++)printf("%lld ",ifac[i]);printf("\n");
}
inline ll c(ll n,ll m)
{
//ll
return fac[n]*ifac[m]%mod*ifac[n-m]%mod;
//printf("ret=%lld\n",ret);return ret;
}
inline ll dpo(ll n,ll p)
{
return fac[n]*ifac[n-p]%mod;
}
struct data
{
ll val;int cnt;
inline void giv(){val=po(val,mod-2);cnt=-cnt;}
data (ll a=1,ll b=0)
{
val=a;cnt=b;
}
friend data operator *(data a,data b)
{
return data((a.val*b.val)%mod,a.cnt+b.cnt);
}
inline void ih(ll sv=1)
{
cnt=(sv==0);val=sv+cnt;
}
inline ll gval(){return val*(cnt==0);}
}pre[N],suf[N];
ll f[N];ll g[N];int k;
int n;int m;int l;
inline void solve()
{
scanf("%d%d%d%d",&n,&m,&l,&k);
int mi=min(min(n,m),l);
if(k>mi){printf("0\n");return;}
for(int i=1;i<=mi;i++)
g[i]=((ll)n*m%mod*l%mod+(mod-n+i)*(m-i)%mod*(l-i)%mod)%mod;
//for(int i=1;i<=mi;i++)
// printf("%lld ",g[i]);printf("\n");
for(int i=1;i<=mi;i++)
pre[i].ih(g[i]),suf[i].ih(g[i]);
for(int i=1;i<=mi;i++)
pre[i]=pre[i]*pre[i-1];
for(int i=mi;i>=1;i--)
suf[i]=suf[i]*suf[i+1];
data iv=pre[mi];iv.giv();
for(int i=1;i<=mi;i++)
g[i]=(iv*pre[i-1]*suf[i+1]).gval();
g[0]=1;
for(int i=1;i<=mi;i++)
(g[i]*=g[i-1])%=mod;
//for(int i=1;i<=mi;i++)
// printf("%lld ",g[i]*56%mod);printf("\n");
for(int i=1;i<=mi;i++)
f[i]=dpo(n,i)*dpo(m,i)%mod*dpo(l,i)%mod*g[i]%mod;
// for(int i=1;i<=mi;i++)
// printf("%lld ",f[i]);printf("\n");
ll res=0;
for(int i=k,tp=0;i<=mi;i++,tp^=1)
if(tp==0)(res+=c(i,k)*f[i])%=mod;
else (res+=(mod-c(i,k))*f[i])%=mod;
printf("%lld\n",res);
for(int i=0;i<=mi+1;i++)pre[i].ih();
for(int i=0;i<=mi+1;i++)suf[i].ih();
}
int main()
{
prew();
int T;scanf("%d",&T);
for(int i=1;i<=T;i++)
solve();
return 0;
}