省选训练赛 #9 题目 E 补题记录
题意:一张 \(n\times m\) 的网格图,行和列的间距为 \(1\)。有 \(n\times m\) 个激光器,每个激光器可以用 \((X_1, X_2, X_3, X_4)\) 表示,其中 \(0\le X_1, X_2, X_3, X_4\le 1\),表示是否向上、向右、向下、向左发射激光,每道激光长度为 \(0.5\)。给定每种激光器的数量,求随机摆放这些激光器时,四个角都是激光器、四条边平行于行和列且都被激光覆盖的正方形数量期望值。
\(n\times m\le 10^6\)
先转化为所有方案中正方形数量总和,但是这里同种激光器之间视为两两不同不同,即所有 \(n\times m\) 个激光器带标号计数。
这样做的好处是可以避免后面的多重集排列计数,改为直接算普通排列数。
枚举四个角使用的激光器种类,每个角使用的共有 \(4\) 种,枚举量为 \(4^4 = 256\)。
然后枚举正方形边长,有 \(\mathcal O(\sqrt {nm})\) 种可能。把剩下的激光器分为三类,只能填横边的、只能填竖边的、两种边都能填的。
枚举 两种边都能填的 给 只能填横边的 补了多少个激光器,用排列数和组合数算一下即可,这里显然是 \(\mathcal O(\sqrt {nm})\)。
总时间复杂度 \(\mathcal O(256nm)\)。
点击查看代码
#include <bits/stdc++.h>
namespace Initial {
#define ll long long
#define ull unsigned long long
#define fi first
#define se second
#define mkp make_pair
#define pir pair <ll, ll>
#define pb push_back
#define i128 __int128
using namespace std;
const ll maxn = 2e6 + 10, inf = 1e18, mod = 1e9 + 7;
ll power(ll a, ll b = mod - 2) {
ll s = 1;
while(b) {
if(b & 1) s = 1ll * s * a %mod;
a = 1ll * a * a %mod, b >>= 1;
} return s;
}
template <class T>
const inline ll pls(const T x, const T y) { return x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void add(T &x, const T y) { x = x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void chkmax(T &x, const T y) { x = x < y? y : x; }
template <class T>
const inline void chkmin(T &x, const T y) { x = x > y? y : x; }
} using namespace Initial;
namespace Read {
char buf[1 << 22], *p1, *p2;
//#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22) - 10, stdin), p1 == p2)? EOF : *p1++)
template <class T>
const inline void rd(T &x) {
char ch; bool neg = 0;
while(!isdigit(ch = getchar()))
if(ch == '-') neg = 1;
x = ch - '0';
while(isdigit(ch = getchar()))
x = (x << 1) + (x << 3) + ch - '0';
if(neg) x = -x;
}
} using Read::rd;
ll t, n, m, a[16], b[16], ans, fac[maxn], ifac[maxn];
unordered_map <ll, ll> mp;
ll C(ll a, ll b) { return a < b || b < 0? 0 : fac[a] * ifac[b] %mod * ifac[a - b] %mod; }
ll A(ll a, ll b) { return a < b || b < 0? 0 : fac[a] * ifac[a - b] %mod; }
void solve() {
rd(n), rd(m); mp.clear(), ans = 0;
fac[0] = 1;
for(ll i = 1; i <= n * m; i++) fac[i] = fac[i - 1] * i %mod;
ifac[n * m] = power(fac[n * m]);
for(ll i = n * m; i; i--) ifac[i - 1] = ifac[i] * i %mod;
for(ll i = 0; i < 16; i++) rd(a[i]);
for(ll p = 0; p < 16; p++)
if(a[p] && ((p & 6) == 6)) {
--a[p];
for(ll q = 0; q < 16; q++)
if(a[q] && ((q & 12) == 12)) {
--a[q];
for(ll r = 0; r < 16; r++)
if(a[r] && ((r & 3) == 3)) {
--a[r];
for(ll s = 0; s < 16; s++)
if(a[s] && ((s & 9) == 9)) {
--a[s];
array <ll, 4> arr = {p, q, r, s};
sort(arr.begin(), arr.end());
ll state = arr[0] * 16 * 16 * 16
+ arr[1] * 16 * 16 + arr[2] * 16 + arr[3];
if(mp.count(state)) { add(ans, mp[state]), ++a[s]; continue; }
++b[p], ++b[q], ++b[r], ++b[s]; ll res = 0;
ll X = 0, Y = 0, Z = a[15];
for(ll w = 0; w < 15; w++) {
if((w & 5) == 5) X += a[w];
if((w & 10) == 10) Y += a[w];
}
for(ll i = 0; i < n - 1 && i < m - 1; i++)
for(ll j = max(0ll, 2 * i - X); j <= 2 * i && j <= Z; j++) {
res = (res + A(Z, j) * C(i << 1, j) %mod * A(X, (i << 1) - j)
%mod * A(Y + Z - j, i << 1) %mod * fac[n * m - 4 * (i + 1)]
%mod * (n - 1 - i) %mod * (m - 1 - i)) %mod;
}
for(ll i = 0; i < 16; i++) res = res * A(a[i] + b[i], b[i]) %mod;
--b[p], --b[q], --b[r], --b[s];
add(ans, res), mp[state] = res;
++a[s];
} ++a[r];
} ++a[q];
} ++a[p];
}
printf("%lld\n", ans * ifac[n * m] %mod);
}
int main() {
rd(t); while(t--) solve();
return 0;
}