[长链剖分优化dp] Codeforces 1499F

题目大意

给定一棵 \(n(2\leq n\leq 5000)\) 个点的树,求一共有多少种方案,删去若干条边后,分裂出的所有树的直径都不超过 \(K\),答案模 \(998244353\)

题解

\(dp[u][i]\) 表示把以 \(u\) 为根的子树分裂成若干棵直径不超过 \(K\) 的树,且以 \(u\) 为根的树的高度为 \(i\) 的方案数。令 \(buf[u][i]\) 表示转移后的 \(dp[u][i]\),则每遇到一个 \(u\) 的孩子 \(v\) 就有:

\[buf[u][i]+=dp[u][i]\times\sum_{j=0}^{k}dp[v][j](切断u-v这条边)\\ buf[u][max(i,j+1)]+=dp[u][i]\times dp[v][j] (i+j+1\leq K)(不切断u-v这条边) \]

\(v\) 的所有转移完成后再把 \(buf\) 赋给 \(dp\),以免重复。但是这样的时间复杂度是 \(O(N^2K)\)
观察一下dp式子,发现第二维的下标和深度有关,于是我们可以使用长链剖分优化树形dp。

回想一下我们设的dp状态,\(dp[v][i]\) 表示将 \(v\) 这棵子树删边,并且删边后以 \(v\) 为根的子树的高度为 \(i\) 的方案数。若当前 \(u\) 只有 \(v\) 这一个儿子,那么使得以 \(v\) 为根的子树高度为 \(i\) 的方案数就相当于使得以 \(u\) 为根的子树高度为 \(i+1\) 并且不删掉 \(u-v\) 这条边的方案数,我们可以维护一个指向dp数组第一维的指针,来实现 \(u\) \(O(1)\) 继承 \(v\) 的dp。若 \(u\) 存在多个儿子,显然这个 \(v\) 应该是重儿子(长链上的儿子) 最优。于是我们 \(O(1)\) 继承重儿子的dp值,对于轻儿子的贡献我们暴力转移。可以发现发生暴力转移的轻儿子都是其各自所在重链的顶端,暴力转移只转移到每条链的长度。所以暴力转移带来的时间复杂度和所有重链的长度之和同阶,而长链剖分后所有重链的长度之和为 \(N\),我们还要取转移 \(dp\) 的第二维,所以长链剖分优化后dp的时间复杂度降为 \(O(NK)\)

Code

#include <bits/stdc++.h>
using namespace std;

#define RG register int
#define LL long long

template<typename elemType>
inline void Read(elemType& T) {
    elemType X = 0, w = 0; char ch = 0;
    while (!isdigit(ch)) { w |= ch == '-'; ch = getchar(); }
    while (isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar();
    T = (w ? -X : X);
}

const int maxn = 5010;
const LL MOD = 998244353LL;

struct Graph {
    struct edge { int Next, to; };
    edge G[maxn << 1];
    int head[maxn];
    int cnt;

    Graph() :cnt(2) {}
    void clear(int n) {
        cnt = 2; fill(head, head + n + 2, 0);
    }
    void add_edge(int u, int v) {
        G[cnt].to = v;
        G[cnt].Next = head[u];
        head[u] = cnt++;
    }
};

Graph G;
int Height[maxn], Hson[maxn];
LL temp[maxn << 2], buf[maxn], * dp[maxn], * id = temp;
int N, K;

void DFS_Init(int u, int fa) {
    Height[u] = 0;
    for (int i = G.head[u]; i; i = G.G[i].Next) {
        int v = G.G[i].to;
        if (v == fa) continue;
        DFS_Init(v, u);
        if (Height[v] > Height[Hson[u]]) Hson[u] = v;
        Height[u] = max(Height[u], Height[v] + 1);
    }
}

LL DFS(int u, int fa) {
    dp[u][0] = 1;
    if (Hson[u]) {
        dp[Hson[u]] = dp[u] + 1;
        LL sum = DFS(Hson[u], u);
        dp[u][0] = sum;
    }
    for (int i = G.head[u]; i; i = G.G[i].Next) {
        int v = G.G[i].to;
        if (v == fa || v == Hson[u]) continue;
        dp[v] = id; id += Height[v] + 1;
        for (int l = 0; l <= min(Height[v], K); ++l) dp[v][l] = 0;
        DFS(v, u);
        for (int l = 0; l <= min(Height[u], K); ++l) buf[l] = 0;
        for (int l = 0; l <= min(Height[u], K); ++l) {
            for (int j = 0; j <= min(Height[v], K); ++j) {
                if (j <= K) buf[l] = (buf[l] + dp[u][l] * dp[v][j] % MOD) % MOD;
                if (l + j + 1 <= K) buf[max(l, j + 1)] = (buf[max(l, j + 1)] + dp[u][l] * dp[v][j] % MOD) % MOD;
            }
        }
        for (int l = 0; l <= min(Height[u], K); ++l) dp[u][l] = buf[l];
    }
    if (Hson[fa] == u) {
        LL sum = 0;
        for (int i = 0; i <= min(Height[u], K); ++i) sum = (sum + dp[u][i]) % MOD;
        return sum;
    }
    return 0;
}

int main() {
    Read(N); Read(K);
    for (int i = 1; i <= N - 1; ++i) {
        int u, v;
        Read(u); Read(v);
        G.add_edge(u, v);
        G.add_edge(v, u);
    }
    DFS_Init(1, 0);
    dp[1] = id; id += Height[1] + 1;
    DFS(1, 0);
    LL ans = 0;
    for (int i = 0; i <= min(Height[1], K); ++i)
        ans = (ans + dp[1][i]) % MOD;
    printf("%I64d\n", ans);

    return 0;
}
posted @ 2021-04-01 18:25  AE酱  阅读(153)  评论(0编辑  收藏  举报