「ZJOI2019」麻将(dp套dp)
Address
Solution
记 \(ans_i\) 表示在给定的 \(13\) 张牌以外,再选出 \(i\) 张牌,使得这 \(13+i\) 张牌不存在 胡的子集的方案数,那么答案就是 \((\frac{1}{(4n-13)!}\sum_{i=1}^{4n-13}i!(4n-13-i)!ans_i)+1\)
接下来,考虑给你一个牌的集合,怎么判断它是否存在一个胡的子集。
首先判断胡的第二个条件:记 \(a_i\) 表示集合中有多少张第 \(i\) 种牌。若 \(\sum [a_i\ge 2]\ge 7\),则存在胡的子集。
再判断第一个条件:考虑 \(dp\)。记 \(f_{i,j,k}\) 表示考虑前 \(i\) 种牌,拿走 \(j\) 对 \((i-1,i)\),拿走 \(k\) 个 \(i\),剩下的牌最多能组成多少个面子。(注意 \(j\) 个 \((i-1,i)\),\(k\) 个 \(i\) 拿出来必须跟后面的牌组成面子)特殊地,\(f_{i,j,k}=-1\) 表示不存在这种状态。
记 \(g_{i,j,k}\) 表示考虑前 \(i\) 种牌,拿走 \(j\) 对 \((i-1,i)\),拿走 \(k\) 个 \(i\),再拿走一个对子,剩下的牌最多能组成多少个面子。
考虑到 \(3\) 个相同的顺子(形如 \(x,x+1,x+2\))可以变成 \(3\) 个相同的刻子 (形如 \(x,x,x\)),因此 \(j,k\in[0,2]\)。
记集合中最大的牌为 \(m\),如果存在 \(g_{m,j,k}\ge 4\),那么存在胡的子集。
考虑转移,枚举加入 \(x\) 张大小为 \(i+1\) 的牌,枚举拿走 \(h\) 张 \(i+1\),那么要组成 \(k\) 对 \((i,i+1)\),组成 \(j\) 对 \((i-1,i,i+1)\),再枚举要不要拿走 \(i+1\) 当对子,有:
考虑建一个自动机,自动机上的每一个节点对应一些不存在胡的子集的集合。每个节点都记录信息:\(f_{m,j,k},g_{m,j,k},cnt\)。\(f_{m,j,k},g_{m,j,k},cnt\) 都相同的集合对应同一个节点,注意 \(m\) 可以不同,所以只要记 \(f_{j,k},g_{j,k},cnt\)。节点之间的转移边权 \(x\) 表示加入 \(x\) 张大小为 \(m+1\) 的牌。
初始节点:\(f_{0,0}=cnt=0\),其它为 \(-1\)。考虑用 \(dfs\) 构造自动机,枚举加入 \(x(x∈[0,4])\) 张新牌转移即可。转移可能成环,扩展出重复状态要剪枝。\(dfs\) 后可得节点数为 \(2091\)。
记 \(ch_{x,y}\) 表示节点 \(x\) 走转移边 \(y\) 到达的节点,\(dp_{i,j,k}\) 表示考虑前 \(i\) 种牌,总共取走 \(j\) 张,目前走到自动机上的节点 \(k\) 的方案数。枚举第 \(i\) 张牌取了 \(h\) 张,有:$$dp_{i,j+h,ch_{k,h}}+=dp_{i-1,j,k}*c_{4-b_i}^{h-b_i}$$
其中 \(b_i\) 表示给定的 \(13\) 张牌中,有多少张大小为 \(i\) 的牌。
\(dp\) 要使用滚动数组,时间复杂度 \(O(2091×n^2)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define ll long long
template <class t>
inline void read(t & res)
{
char ch;
while (ch = getchar(), !isdigit(ch));
res = ch ^ 48;
while (ch = getchar(), isdigit(ch))
res = res * 10 + (ch ^ 48);
}
const int e = 505, o = 3005, mod = 998244353;
struct point
{
int cnt, f[3][3], g[3][3];
inline bool check()
{
if (cnt >= 7) return 1;
for (int i = 0; i <= 2; i++)
for (int j = 0; j <= 2; j++)
if (g[i][j] >= 4) return 1;
return 0;
}
inline point trans(int x)
{
point a;
int i, j, k;
a.cnt = min(7, cnt + (x >= 2));
for (i = 0; i <= 2; i++)
for (j = 0; j <= 2; j++)
a.f[i][j] = a.g[i][j] = -1;
for (i = 0; i <= 2; i++)
for (j = 0; j <= 2; j++)
{
if (f[i][j] != -1)
{
for (k = 0; i + j + k <= x && k <= 2; k++)
a.f[j][k] = max(a.f[j][k], f[i][j] + i + (x - i - j - k >= 3));
for (k = 0; i + j + k <= x - 2; k++)
a.g[j][k] = max(a.g[j][k], f[i][j] + i);
}
if (g[i][j] != -1)
{
for (k = 0; i + j + k <= x && k <= 2; k++)
a.g[j][k] = max(a.g[j][k], g[i][j] + i + (x - i - j - k >= 3));
}
}
for (i = 0; i <= 2; i++)
for (j = 0; j <= 2; j++)
a.f[i][j] = min(a.f[i][j], 4), a.g[i][j] = min(a.g[i][j], 4);
return a;
}
};
inline bool operator < (point a, point b)
{
if (a.cnt != b.cnt) return a.cnt < b.cnt;
int i, j;
for (i = 0; i <= 2; i++)
for (j = 0; j <= 2; j++)
{
if (a.f[i][j] != b.f[i][j]) return a.f[i][j] < b.f[i][j];
if (a.g[i][j] != b.g[i][j]) return a.g[i][j] < b.g[i][j];
}
return 0;
}
inline bool operator == (point a, point b)
{
if (a.cnt != b.cnt) return 0;
int i, j;
for (i = 0; i <= 2; i++)
for (j = 0; j <= 2; j++)
{
if (a.f[i][j] != b.f[i][j]) return 0;
if (a.g[i][j] != b.g[i][j]) return 0;
}
return 1;
}
map<point, int> id;
int cnt, fac[e], inv[e], ch[o][6], n, m, a[e], dp[2][e][o], ans;
inline void dfs(point a)
{
int x = id[a];
for (int i = 0; i <= 4; i++)
{
point b = a.trans(i);
if (b.check()) continue;
int y = id[b];
if (y) ch[x][i] = y;
else
{
id[b] = ++cnt;
ch[x][i] = cnt;
dfs(b);
}
}
}
inline int ksm(int x, int y)
{
int res = 1;
while (y)
{
if (y & 1) res = (ll)res * x % mod;
y >>= 1;
x = (ll)x * x % mod;
}
return res;
}
inline void init()
{
point s;
s.cnt = 0;
int i, j;
for (i = 0; i <= 2; i++)
for (j = 0; j <= 2; j++)
s.f[i][j] = s.g[i][j] = -1;
s.f[0][0] = 0;
id[s] = cnt = 1;
dfs(s);
}
inline void add(int &x, int y)
{
(x += y) >= mod && (x -= mod);
}
inline int c(int x, int y)
{
return (ll)fac[x] * inv[y] % mod * inv[x - y] % mod;
}
inline void prepare()
{
int i;
fac[0] = 1;
for (i = 1; i <= m; i++) fac[i] = (ll)fac[i - 1] * i % mod;
inv[m] = ksm(fac[m], mod - 2);
for (i = m - 1; i >= 0; i--) inv[i] = (ll)inv[i + 1] * (i + 1) % mod;
}
int main()
{
freopen("mahjong.in", "r", stdin);
freopen("mahjong.out", "w", stdout);
read(n); m = n << 2;
int i, j, k, h, x, y, sum = 0;
init(); prepare();
for (i = 1; i <= 13; i++) read(x), read(y), a[x]++;
dp[0][0][1] = 1;
for (i = 1; i <= n; i++)
{
int nxt = i & 1, lst = nxt ^ 1;
for (j = 0; j <= sum + 4; j++)
for (k = 1; k <= cnt; k++)
dp[nxt][j][k] = 0;
for (j = 0; j <= sum; j++)
for (k = 1; k <= cnt; k++)
if (dp[lst][j][k])
{
int v = dp[lst][j][k];
for (h = a[i]; h <= 4; h++)
if (ch[k][h])
dp[nxt][j + h][ch[k][h]] = (dp[nxt][j + h][ch[k][h]] + (ll)v
* c(4 - a[i], h - a[i])) % mod;
}
sum += 4;
}
for (i = 1; i <= m - 13; i++)
{
sum = 0;
for (j = 1; j <= cnt; j++) add(sum, dp[n & 1][13 + i][j]);
ans = (ans + (ll)sum * fac[i] % mod * fac[m - 13 - i]) % mod;
}
ans = (ll)ans * inv[m - 13] % mod;
add(ans, 1);
cout << ans << endl;
fclose(stdin);
fclose(stdout);
return 0;
}