ARC142D Deterministic Placing
原题链接 https://atcoder.jp/contests/arc142/tasks/arc142_d
对我来说,这是一道很复杂的 \(dp\) 题,很考验基本功,也十分考验分析问题的准确性。
考场上我的大致思路已经大差不差了,但是 \(dp\) 转移的细节实在是过于冗杂,导致我经过了一个月才把这道题目搞定,其中经历了各种没有考虑周全的情况,做到我精神崩溃,而发现网上没有和我思路完全一致的博客,于是只好自己慢慢 debug。
我们发现,本题实际上是统计将一棵树划分成互不相交链的方案数。其中,每条链有一个端点是空着的,其他所有点都放有棋子,每次移动,棋子就在链的两端之间滑动。经过尝试,我们 \(dp\) 状态要分为 \(7\) 种情况:
- \(0\) 该点为不放棋子的端点,并且链向子树方向延伸
- \(1\) 该点为不放棋子的端点,并且链向父亲方向延伸
- \(2\) 该点为放棋子的端点,并且链向子树方向延伸
- \(3\) 该点为链的中段(即不为两端点),且棋子向子树方向滑动
- \(4\) 该点为放棋子的端点,并且链向父亲方向延伸
- \(5\) 该点为链的中段,且棋子向父亲方向滑动
- \(6\) 该点为链的中段,且横向跨越子树(即向两个儿子延伸)
经过仔细分析,有如下转移式:
\[\begin{aligned}
f[u][0] &= \sum_{v_1} (f[v_1][4] + f[v_1][5])\times\prod_{v\neq v_1} f[v][2]\\
f[u][1] &= \prod_{v} f[v][2]\\
f[u][2] &= \sum_{v_1} (f[v_1][1]+ f[v_1][3]) \times \prod_{v\neq v_1} f[v][0]\\
f[u][3] &= \sum_{v_1} (f[v_1][1] + f[v_1][3]) \times \prod_{v\neq v_1}f[v][6]\\
f[u][4] &= \prod_{v} f[v][0]\\
f[u][5] &= \sum_{v_1} (f[v_1][4] + f[v_1][5])\times \prod_{v\neq v_1} f[v][6]\\
f[u][6] &= \sum_{v_1 \neq v_2} (f[v_1][1]+f[v_1][3])\times(f[v_2][4] + f[v_2][5])\times \prod_{v\neq v_1, v_2} f[v][6]
\end{aligned}
\]
请仔细体会状态转移的细节。
如何 \(O(1)\) 维护转移呢?一开始我想复杂了,其实很简单,我们可以借助若干个变量来避免枚举 \(v_1, v_2\)。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 200005;
const int mod = 998244353;
int n;
ll f[maxn][7];
int head[maxn], nxt[maxn << 1], tail[maxn << 1], ecnt;
inline ll mul(ll a, ll b) { return (a * b) % mod; }
inline ll add(ll a, ll b) { a += b; if (a >= mod) a -= mod; if (a < 0) a += mod; return a; }
void addedge(int u, int v) {
nxt[++ecnt] = head[u];
head[u] = ecnt;
tail[ecnt] = v;
}
ll tmp[4];
void dfs(int u, int p) {
for (int e = head[u]; e; e = nxt[e]) {
int v = tail[e];
if (v != p) dfs(v, u);
}
tmp[0] = 1; tmp[1] = 0;
for (int e = head[u]; e; e = nxt[e]) {
int v = tail[e];
if (v != p) {
tmp[1] = mul(tmp[1], f[v][2]);
tmp[1] = add(tmp[1], mul(tmp[0], add(f[v][4], f[v][5])));
tmp[0] = mul(tmp[0], f[v][2]);
}
}
f[u][0] = tmp[1];
f[u][1] = tmp[0];
tmp[0] = 1; tmp[1] = 0;
for (int e = head[u]; e; e = nxt[e]) {
int v = tail[e];
if (v != p) {
tmp[1] = mul(tmp[1], f[v][0]);
tmp[1] = add(tmp[1], mul(tmp[0], add(f[v][1], f[v][3])));
tmp[0] = mul(tmp[0], f[v][0]);
}
}
f[u][2] = tmp[1];
f[u][4] = tmp[0];
tmp[0] = 1; tmp[1] = 0;
for (int e = head[u]; e; e = nxt[e]) {
int v = tail[e];
if (v != p) {
tmp[1] = mul(tmp[1], f[v][6]);
tmp[1] = add(tmp[1], mul(tmp[0], add(f[v][1], f[v][3])));
tmp[0] = mul(tmp[0], f[v][6]);
}
}
f[u][3] = tmp[1];
tmp[0] = 1; tmp[1] = 0;
for (int e = head[u]; e; e = nxt[e]) {
int v = tail[e];
if (v != p) {
tmp[1] = mul(tmp[1], f[v][6]);
tmp[1] = add(tmp[1], mul(tmp[0], add(f[v][4], f[v][5])));
tmp[0] = mul(tmp[0], f[v][6]);
}
}
f[u][5] = tmp[1];
tmp[0] = 1; tmp[1] = tmp[2] = tmp[3] = 0;
for (int e = head[u]; e; e = nxt[e]) {
int v = tail[e];
if (v != p) {
tmp[3] = mul(tmp[3], f[v][6]);
tmp[3] = add(tmp[3], mul(tmp[1], add(f[v][4], f[v][5])));
tmp[3] = add(tmp[3], mul(tmp[2], add(f[v][1], f[v][3])));
tmp[1] = mul(tmp[1], f[v][6]);
tmp[2] = mul(tmp[2], f[v][6]);
tmp[1] = add(tmp[1], mul(tmp[0], add(f[v][1], f[v][3])));
tmp[2] = add(tmp[2], mul(tmp[0], add(f[v][4], f[v][5])));
tmp[0] = mul(tmp[0], f[v][6]);
}
}
f[u][6] = tmp[3];
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin >> n;
for (int i = 1; i <= n - 1; i++) {
int u, v; cin >> u >> v;
addedge(u, v); addedge(v, u);
}
dfs(1, 0);
cout << add(add(f[1][0], f[1][2]), f[1][6]) << endl;
return 0;
}