[长链剖分优化dp] Codeforces 1499F
题目大意
给定一棵 \(n(2\leq n\leq 5000)\) 个点的树,求一共有多少种方案,删去若干条边后,分裂出的所有树的直径都不超过 \(K\),答案模 \(998244353\)。
题解
设 \(dp[u][i]\) 表示把以 \(u\) 为根的子树分裂成若干棵直径不超过 \(K\) 的树,且以 \(u\) 为根的树的高度为 \(i\) 的方案数。令 \(buf[u][i]\) 表示转移后的 \(dp[u][i]\),则每遇到一个 \(u\) 的孩子 \(v\) 就有:
\(v\) 的所有转移完成后再把 \(buf\) 赋给 \(dp\),以免重复。但是这样的时间复杂度是 \(O(N^2K)\)。
观察一下dp式子,发现第二维的下标和深度有关,于是我们可以使用长链剖分优化树形dp。
回想一下我们设的dp状态,\(dp[v][i]\) 表示将 \(v\) 这棵子树删边,并且删边后以 \(v\) 为根的子树的高度为 \(i\) 的方案数。若当前 \(u\) 只有 \(v\) 这一个儿子,那么使得以 \(v\) 为根的子树高度为 \(i\) 的方案数就相当于使得以 \(u\) 为根的子树高度为 \(i+1\) 并且不删掉 \(u-v\) 这条边的方案数,我们可以维护一个指向dp数组第一维的指针,来实现 \(u\) \(O(1)\) 继承 \(v\) 的dp。若 \(u\) 存在多个儿子,显然这个 \(v\) 应该是重儿子(长链上的儿子) 最优。于是我们 \(O(1)\) 继承重儿子的dp值,对于轻儿子的贡献我们暴力转移。可以发现发生暴力转移的轻儿子都是其各自所在重链的顶端,暴力转移只转移到每条链的长度。所以暴力转移带来的时间复杂度和所有重链的长度之和同阶,而长链剖分后所有重链的长度之和为 \(N\),我们还要取转移 \(dp\) 的第二维,所以长链剖分优化后dp的时间复杂度降为 \(O(NK)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define RG register int
#define LL long long
template<typename elemType>
inline void Read(elemType& T) {
elemType X = 0, w = 0; char ch = 0;
while (!isdigit(ch)) { w |= ch == '-'; ch = getchar(); }
while (isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar();
T = (w ? -X : X);
}
const int maxn = 5010;
const LL MOD = 998244353LL;
struct Graph {
struct edge { int Next, to; };
edge G[maxn << 1];
int head[maxn];
int cnt;
Graph() :cnt(2) {}
void clear(int n) {
cnt = 2; fill(head, head + n + 2, 0);
}
void add_edge(int u, int v) {
G[cnt].to = v;
G[cnt].Next = head[u];
head[u] = cnt++;
}
};
Graph G;
int Height[maxn], Hson[maxn];
LL temp[maxn << 2], buf[maxn], * dp[maxn], * id = temp;
int N, K;
void DFS_Init(int u, int fa) {
Height[u] = 0;
for (int i = G.head[u]; i; i = G.G[i].Next) {
int v = G.G[i].to;
if (v == fa) continue;
DFS_Init(v, u);
if (Height[v] > Height[Hson[u]]) Hson[u] = v;
Height[u] = max(Height[u], Height[v] + 1);
}
}
LL DFS(int u, int fa) {
dp[u][0] = 1;
if (Hson[u]) {
dp[Hson[u]] = dp[u] + 1;
LL sum = DFS(Hson[u], u);
dp[u][0] = sum;
}
for (int i = G.head[u]; i; i = G.G[i].Next) {
int v = G.G[i].to;
if (v == fa || v == Hson[u]) continue;
dp[v] = id; id += Height[v] + 1;
for (int l = 0; l <= min(Height[v], K); ++l) dp[v][l] = 0;
DFS(v, u);
for (int l = 0; l <= min(Height[u], K); ++l) buf[l] = 0;
for (int l = 0; l <= min(Height[u], K); ++l) {
for (int j = 0; j <= min(Height[v], K); ++j) {
if (j <= K) buf[l] = (buf[l] + dp[u][l] * dp[v][j] % MOD) % MOD;
if (l + j + 1 <= K) buf[max(l, j + 1)] = (buf[max(l, j + 1)] + dp[u][l] * dp[v][j] % MOD) % MOD;
}
}
for (int l = 0; l <= min(Height[u], K); ++l) dp[u][l] = buf[l];
}
if (Hson[fa] == u) {
LL sum = 0;
for (int i = 0; i <= min(Height[u], K); ++i) sum = (sum + dp[u][i]) % MOD;
return sum;
}
return 0;
}
int main() {
Read(N); Read(K);
for (int i = 1; i <= N - 1; ++i) {
int u, v;
Read(u); Read(v);
G.add_edge(u, v);
G.add_edge(v, u);
}
DFS_Init(1, 0);
dp[1] = id; id += Height[1] + 1;
DFS(1, 0);
LL ans = 0;
for (int i = 0; i <= min(Height[1], K); ++i)
ans = (ans + dp[1][i]) % MOD;
printf("%I64d\n", ans);
return 0;
}