[CTS2019] 随机立方体
题目描述
有一个 \(n\times m\times l\) 的立方体,立方体中每个格子上都有一个数,如果某个格子上的数比三维坐标至少有一维相同的其他格子上的数都要大的话,我们就称它是极大的。
现在将 \(1\sim n\times m\times l\) 这 \(n\times m\times l\) 个数等概率随机填入 \(n\times m\times l\) 个格子(即任意数字出现在任意格子上的概率均相等),使得每个数恰出现一次,求恰有 \(k\) 个极大的数的概率。答案对 \(998244353\)(一个质数)取模。
\(\textbf{Data Range:} 1 \leq n, m, l \leq 5\times 10^5, 1 \leq k \leq 100, 1 \leq T \leq 10\)。
要求的是至少 \(k\) 个,下意识二项式反演。
设 \(f_i\) 表示恰好 \(i\) 个极大数的时候的概率。
设 \(g_i\) 表示至少有 \(i\) 个极大数的时候的概率。
那么答案就为 \(f_k\)。
于是我们只用求 \(g\) 就可以求出答案了。
以下考虑求 \(g_k\)。
首先考虑选出 \(k\) 个点的方案数,对于每个维度都是一个下降幂。
我们这里选择的点是有序的。
那么自然为 \(P_{n}^k \times P_{m}^k \times P_{l}^k\)。
你可以先从二维平面角度进行考虑。
从小到大进行考虑,假设现在是第 \(i\) 点,前 \(i - 1\) 的极大值确实放在了前 \(i-1\) 个位置。
现在是第 \(i\) 个位置。
注意到一个极大值是第 \(i\) 小,意味着它是前 \(i\) 个位置所占的截面的并上的最大值。
思考一下,这东西构成了一个树形关系,而方案数就等价于树形拓扑序的一种求法。
我们假设一条对角线上下来的就是顺序摆放。
我们考虑第一个点他能支配的位置就是对应的第一列和第一行,其实也就是 \(nm - (n - 1) \times (m - 1)\)。
第二个点呢?因为我们是顺序摆放的,他需要比第一个点支配的都大,同时他又要支配他自己的第二行和第二列。
那么就是 \(nm - (n - 2)\times(m-2)\)。
以此类推,可以推测出第 \(i\) 个点。
以此类推,可以推测出三维情况。
\(g_k = P_{n}^k \times P_{m}^k \times P_{l}^k \prod_{i=1}^k\dfrac{1}{nml - (n - i)(m-i)(l-i)}\)。
后面的部分我们可以预处理一个后缀积,然后复杂度就是 \(O(Tn)\) 的了。
// 德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱
// 德丽莎的可爱在于德丽莎很可爱,德丽莎为什么很可爱呢,这是因为德丽莎很可爱!
// 没有力量的理想是戏言,没有理想的力量是空虚
#include <bits/stdc++.h>
#define LL long long
#define int long long
using namespace std;
namespace io {
const int SIZE = (1 << 19) + 1; char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, c, qu[55]; int f, qr;
#define gc() (iS == iT ? (iT = (iS = ibuf) + fread (ibuf, 1, SIZE, stdin), (iS == iT ? EOF : *iS ++)) : *iS ++)
inline void flush () { fwrite (obuf, 1, oS - obuf, stdout); oS = obuf; }
inline void putc (char x) { *oS ++ = x; if (oS == oT) flush (); }
template <class I> inline void gi (I &x) { for (f = 1, c = gc(); c < '0' || c > '9'; c = gc()) if (c == '-') f = -1; for (x = 0; c <= '9' && c >= '0'; c = gc()) x = x * 10 + (c & 15); x *= f; }
string getstr(void) { string s = ""; char c = gc(); while (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF) c = gc(); while (!(c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF))s.push_back(c), c = gc(); return s;}
template <class I> inline void print (I x) { if (!x) putc ('0'); if (x < 0) putc ('-'), x = -x; while (x) qu[++ qr] = x % 10 + '0', x /= 10; while (qr) putc (qu[qr --]); }
struct Flusher_ {~Flusher_(){flush();}}io_flusher_;
}
using io :: gi; using io :: putc; using io :: print;
template<class T> bool chkmin(T &a, T b) { return a > b ? (a = b, true) : false; }
template<class T> bool chkmax(T &a, T b) { return a < b ? (a = b, true) : false; }
#define rep(i, l, r) for (int i = (l); i <= (r); i++)
#define repd(i, l, r) for (int i = (l); i >= (r); i--)
const int N = 7e6, mod = 998244353;
int power(int a,int b) {
int ans = 1;
while (b) { if (b & 1) ans = ans * a % mod; b >>= 1; a = a * a % mod; }
return ans;
}
int inv(int x) { return power(x, mod - 2); }
int n, m, l, k, f[N], g[N], fac[N], ifac[N];
void init() {
fac[0] = 1;
rep (i, 1, N - 1) fac[i] = fac[i - 1] * i % mod;
ifac[N - 1] = power(fac[N - 1], mod - 2);
repd (i, N - 2, 1) ifac[i] = ifac[i + 1] * (i + 1) % mod;
ifac[0] = 1;
}
int P(int n,int m) { return fac[n] * ifac[n - m] % mod; }
int C(int n,int m) {
if (n < m) return 0;
if (!m || n == m) return 1;
return fac[n] * ifac[m] % mod* ifac[n - m] % mod;
}
int hs[N], sh[N];
void pre2(int s) {
hs[0] = 1;
rep (i, 1, s) hs[i] = hs[i - 1] * ((n * m % mod * l % mod - (n - i) % mod * (m - i) % mod * (l - i) % mod) % mod + mod) % mod;
sh[s] = inv(hs[s]);
repd (i, s - 1, 0) sh[i] = sh[i + 1] * ((n * m % mod * l % mod - (n - (i + 1)) % mod * (m - (i + 1)) % mod * (l - (i + 1)) % mod) % mod + mod) % mod;
return ;
}
void solve() {
gi(n), gi(m), gi(l), gi(k);
if (n > m) n ^= m ^= n ^= m;
if (n > l) n ^= l ^= n ^= l;
int qwq = N; chkmin(qwq, n); chkmin(qwq, m); chkmin(qwq, l);
pre2(qwq);
// cout << n << " " << m << " " << l << " " << k << " " << qwq << "\n";
int ans = 0;
// cout << k << " " << qwq
rep (i, k, qwq) {
int op = 1;
if ((i - k) % 2 == 0) op = 1;
else op = -1;
op = op * C(i, k);
op += mod; op %= mod;
// cout << P(n, i) << " " << n << " " << i << "\n";
// cout << P(n, i) << " " << P(m, i) << " " << P(l, i) << "\n";
op = op * P(n, i) % mod * P(m, i) % mod * P(l, i) % mod;
op = op * sh[i] % mod;
op += mod; op %= mod;
ans += op; ans %= mod;
}
print(ans), putc('\n');
// cout << "qwq\n";
}
signed main () {
#ifdef LOCAL_DEFINE
freopen("1.in", "r", stdin);
freopen("1.ans", "w", stdout);
#endif
init();
int T; gi(T);
while (T--) solve();
#ifdef LOCAL_DEFINE
cerr << "Time elapsed: " << 1.0 * clock() / CLOCKS_PER_SEC << " s.\n";
#endif
return 0;
}