题解-CTS2019随机立方体
problem
题意概要:一个 \(n\times m\times l\) 的立方体,立方体中每个格子上都有一个数,如果某个格子上的数比三维坐标中至少有一维相同的其他格子上的数都要大的话,我们就称它是极大的。将 \(n\times m\times l\) 的排列随机填入这些格子,求恰有 \(k\) 个极大的数的概率。\(T\) 组数据。
\(T\le 10,\ 1\le n,m,l\le 5\times 10^6,\ 1\le k \le 100\),时限 \(5s\)
Solution
为啥CTS比APIO难这多?我果然还是不会数数呐
根据这一个月来的数数经验……发现:
- 求答案为 \(k\) (权值)的题:应该是转换成 “答案小于等于 \(k\)” 减去 “答案小于等于 \(k-1\)”。
- 求恰好 \(k\) 个(数量)的题:应该是转换成 “至少有 \(k\) 个”,再利用二项式反演得到答案。
所以这里设 \(f[i]\) 表示至少有 \(i\) 个极大值的概率,答案即为:
问题转换成求 \(f[i]\)。
首先选出这 \(i\) 个确定的极大值的坐标,系数为:\(P_n^iP_m^iP_l^i\)(考虑顺序,后边要用到)
定义一个坐标 \((x_0,y_0,z_0)\) 的控制范围为 三个平面 \(x=x_0,y=y_0,z=z_0\) 的并(所以极大值的定义即为:该点权值 为 其控制范围上点的权值最大值)
问题转化为求这 \(i\) 个坐标的控制范围的并内,有多少种安排数字顺序(不是权值而是大小关系,因为目前已经选出了这些坐标,并且现在要求的是概率,而非方案数)的方法,使得对于每个选定的坐标,其都为自己控制范围内的最大值。
直接考虑这个问题不好考虑,需要找到突破口,而这里的突破口就是 “这些控制范围的并内,最大值一定是一个极大值”。由于我们在选出这 \(i\) 个坐标的时候,已经考虑了顺序问题,所以若 这\(i\)个极大值控制范围的并 大小为 \(S_i\),则这件事情发生的概率为 \(\frac 1{S_i}\),并且发生这件事情后,所有只被这个极大值控制的点都没用了,则将问题转化为选出 \(i-1\) 个坐标的情况,如此递归,计算得到的权值为 \(\prod_{j=1}^i\frac 1{S_j}\)
至于如何计算 \(S_i\),可以发现取全局减多余可得 \(S_i=nml-(n-i)(m-i)(l-i)\),或是直接考虑容斥 \(S_i=(nm+nl+ml)i-(n+m+l)i^2+i^3\)
汇总一下,得到:
这个式子是线性的,但由于后头那个 \(\prod\) 需要求逆元,所以复杂度为 \(O(\min\{n,m,l\}\log p)\),如此能过 \(80\)。
考虑到这个东西实际上是要求每一个前缀积的逆元,和求阶乘逆元类似,可以先求出整个前缀积的逆元,再从后面往前乘,复杂度 \(O(\min\{n,m,l\}+\log p)\)
设:
\[a_i=nml-(n-i)(m-i)(l-i)\\ b_i=\prod_{j=1}^ia_j\\ c_i=\frac 1{b_i} \]可以 \(O(1)\) 得到 \(a_i\),\(O(n)\) 得到 \(b_i\),\(O(\log p)\) 得到 \(c_n=\frac 1{b_n}\),\(O(n)\) 得到 \(c_i=c_{i+1}\cdot a_{i+1}\)
Code
//loj-3119
#include <cstdio>
typedef long long ll;
const int N = 5001010, p = 998244353;
int fac[N], ifac[N];
int coe[N], h[N], ih[N];
int n, m, l, k;
inline int qpow(int A, int B) {
int res = 1; while(B) {
if(B&1) res = (ll)res * A%p;
A = (ll)A * A%p, B >>= 1;
} return res;
}
int main() {
fac[0] = 1;
for(int i=1;i<N;++i) fac[i] = (ll)fac[i-1] * i%p;
ifac[N-1] = qpow(fac[N-1], p-2);
for(int i=N-1;i;--i) ifac[i-1] = (ll)ifac[i] * i%p;
int T; scanf("%d",&T);
while(T--) {
scanf("%d%d%d%d", &n, &m, &l, &k);
if(n > m) n ^= m, m ^= n, n ^= m;
if(n > l) n ^= l, l ^= n, n ^= l;
const int s2 = ((ll)n*m + (ll)m*l + (ll)n*l)%p, s1 = n+m+l;
for(int i=h[0]=1;i<=n;++i) {
coe[i] = (s2 - (ll)s1 * i + (ll)i*i + (ll)p*p)%p * i%p;
h[i] = (ll)h[i-1] * coe[i]%p;
}
ih[n] = qpow(h[n], p-2);
for(int i=n;i;--i) ih[i-1] = (ll)ih[i] * coe[i]%p;
const int nml_k = (ll)fac[n] * fac[m]%p * fac[l]%p * ifac[k]%p;
int Ans = 0;
for(int i=k;i<=n;++i) {
int vl = (ll)nml_k * fac[i]%p * ifac[i-k]%p * ifac[n-i]%p * ifac[m-i]%p * ifac[l-i]%p;
vl = (ll)vl * ih[i]%p;
if(i-k&1) vl = p - vl;
(Ans += vl) >= p && (Ans -= p);
}
printf("%d\n", Ans);
}
return 0;
}