「解题报告」P9167 [省选联考 2023] 城市建造

考场降智系列。

首先手模一下,发现题意就是让我们选出一个点集,删去之间的所有边,使得这个点集中的每个点都不在同一个连通块中。首先可以发现选择的点集一定是若干个点双,如果一个点双中的点没有全部被选择,那么一定有其中至少两个点仍然在同一连通块中。其次,我们发现选取的点集一定是联通的,否则同样会有两个点在同一连通块中。不难发现这就是点集合法的充要条件。

那么问题转换成了:给定一个图,删去若干个联通的点双,使得剩下的连通块的大小的极差 \(\le k\)

首先点双容易想到圆方树转成树,现在就是一个树形 DP 了。极差不好做,我们考虑直接枚举连通块的大小。容易发现,当 \(k=0\) 时,这个连通块的大小只能是 \(n\) 的因子,数量是 \(O(\sqrt{n})\) 的;\(k=1\) 时,假如有 \(k\) 个连通块,那么每个连通块的大小为 \(\lfloor \frac{n}{k}\rfloor\)\(\lceil \frac{n}{k}\rceil\),而这个连通块大小也只有 \(O(\sqrt{n})\) 种。

假如我们枚举当前的连通块大小为 \(t\)(与 \(t+1\)),那么我们直接在圆方树上跑树形背包,就很容易能够 DP 求出答案。具体的,对于每一个方点,它的方案为所有儿子的方案数的乘积;对于每一个圆点,我们需要让当前圆点所在的连通块大小为 \(t\)(与 \(t+1\)),那么我们可以选择一些儿子保留,选择一些儿子删除,这个可以通过树形背包实现。

对于选取的连通块,我们在深度最小的位置统计答案,那么在每个点统计它作为选取的连通块的深度最小的点时的方案数即可。这样的复杂度为 \(O(nt)\)。假如我们对每一个大小做一遍,由于调和级数为 \(O(\log n)\),那么总复杂度为 \(O(n^2 \log n)\),能够拿到 65 分。

以上是考场上想到的做法,但是考场降智,忘了怎么建圆方树了,痛失 40 分。

我们可以更仔细的分析一下这个背包的选取过程。先考虑 \(k=0\) 的情况。我们发现一件事情,就是对于一个子树,如果我们选取它,那么这个子树的大小至少要为 \(t\),否则内部不可能划分出大小为 \(t\) 的连通块。而且,如果子树大小 \(\ge t\),那么我们一定不能不选,因为如果不选,根节点所在的连通块大小至少为 \(t+1\),也不符合题意。这样,我们可以按照子树大小与 \(t\) 的关系,唯一确定下来子树的选取方案,然后判断一下形成的连通块大小是否等于 \(t\) 即可。这样就能优化到单次 \(O(n)\) 了。

\(k=1\) 的情况同样可以类似的考虑,我们发现,子树大小 \(<t\) 时同样一定不选,子树大小 \(>t\) 时同样一定选,而对于 \(=t\) 的情况,我们可以有一个不选,这样根节点所在连通块大小为 \(t+1\),同样符合限制。那么,我们就可以枚举哪个 \(=t\) 的子树不选,然后计算答案即可,这样复杂度也是 \(O(n)\) 的,总复杂度就是 \(O(n \sqrt{n})\) 的了。

据说这个做法卡卡常可以过,但是我自己卡了一上午,仍然只能做到最大的点本机跑 0.8s,luogu 上 TLE。我还是比较菜。

上述做法感觉上其实很多节点都不存在合法的划分方案。我们考虑不枚举连通块大小分别跑 DP,而是对每个连通块大小一起跑 DP。分析上述算法过程,我们发现,每次选取的子树都是一个后缀(中间可能删去一个树)。那么我们可以先枚举选取的这个后缀,然后计算其对应的连通块大小,只对这些连通块大小进行 DP。由于前缀的个数是 \(O(deg)\) 的,这样每个圆点的合法的连通块大小就只有 \(O(deg)\) 种,总共的状态数其实就只有 \(O(n)\) 个。对于方点来说,它的方案数等于所有子树的方案数的乘积,那么它的合法状态数也一定是小于等于所有子树中状态数的最小值,所以方点的状态数也是 \(O(n)\) 的。于是这样跑 DP 就能 \(O(n)\) 得出答案了。细节很多,我写了好久才写出来,我好菜。

视实现而定,可做到 \(O(n \log n)\)\(O(n)\) 的复杂度。我写的 \(O(n \log n)\),复杂度瓶颈在给子树排序。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005, P = 998244353;
int n, m, k;
int ans[MAXN][2];
set<int> s;
unordered_map<int, int> f[MAXN][2], fsuf[MAXN];
struct Tree {
    vector<int> e[MAXN], re[MAXN];
    int tot;
    void add(int u, int v) {
        e[u].push_back(v), e[v].push_back(u);
    }
    int siz[MAXN];
    int to[MAXN];
    void dfs(int u, int pre) {
        siz[u] = u <= n;
        for (int v : e[u]) if (v != pre) {
            dfs(v, u);
            siz[u] += siz[v];
        }
        sort(e[u].begin(), e[u].end(), [&](int a, int b) { return siz[a] < siz[b]; });
        int deg = 0;
        for (int v : e[u]) if (v != pre) {
            deg++;
            to[deg] = v;
        }
        if (u <= n) {
            { // k = 0
                for (int i = 1; i <= deg; i++) 
                    fsuf[i].clear();
                for (int i = deg; i >= 1; i--) {
                    int v = to[i];
                    if (i == deg) fsuf[i] = f[v][0];
                    else for (auto p : f[v][0]) {
                        int t = p.first, w = p.second;
                        fsuf[i][t] = 1ll * fsuf[i + 1][t] * w % P;
                    }
                }
                int sz = 1, up = n - siz[u];
                for (int i = 1; i <= deg; i++) {
                    if (siz[to[i]] >= sz) {
                        if (s.count(sz)) f[u][0][sz] = (f[u][0][sz] + fsuf[i][sz]) % P;
                        if (s.count(sz + up)) ans[sz + up][0] = (ans[sz + up][0] + fsuf[i][sz + up]) % P;
                    }
                    sz += siz[to[i]];
                }
                if (s.count(sz)) f[u][0][sz] = (f[u][0][sz] + 1) % P;
                if (s.count(sz + up)) ans[sz + up][0] = (ans[sz + up][0] + 1) % P;
            }
            { // k = 1
                for (int i = 1; i <= deg; i++) 
                    fsuf[i].clear();
                for (int i = deg; i >= 1; i--) {
                    int v = to[i];
                    if (i == deg) fsuf[i] = f[v][1];
                    else for (auto p : f[v][1]) {
                        int t = p.first, w = p.second;
                        fsuf[i][t] = 1ll * fsuf[i + 1][t] * w % P;
                    }
                }
                int sz = 1, up = n - siz[u];
                for (int i = 1; i <= deg; i++) {
                    if (siz[to[i]] >= sz) {
                        if (s.count(sz)) f[u][1][sz] = (f[u][1][sz] + fsuf[i][sz]) % P;
                    }
                    if (siz[to[i]] >= sz + up) {
                        if (s.count(sz + up)) {
                            ans[sz + up][1] = (ans[sz + up][1] + fsuf[i][sz + up]) % P;
                        }
                    }
                    if (sz - 1 > 0 && siz[to[i - 1]] < sz - 1 && siz[to[i]] >= sz - 1) {
                        if (s.count(sz - 1)) f[u][1][sz - 1] = (f[u][1][sz - 1] + fsuf[i][sz - 1]) % P;
                    }
                    if (sz + up - 1 > 0 && siz[to[i - 1]] < sz + up - 1 && siz[to[i]] >= sz + up - 1) {
                        if (s.count(sz + up - 1)) {
                            ans[sz + up - 1][1] = (ans[sz + up - 1][1] + fsuf[i][sz + up - 1]) % P;
                        }
                    }
                    sz += siz[to[i]];
                }
                if (s.count(sz)) f[u][1][sz] = (f[u][1][sz] + 1) % P;
                if (deg && siz[to[deg]] < sz - 1) {
                    if (s.count(sz - 1)) f[u][1][sz - 1] = (f[u][1][sz - 1] + 1) % P;
                }
                if (deg && siz[to[deg]] < sz + up - 1) {
                    if (s.count(sz + up - 1)) ans[sz + up - 1][1] = (ans[sz + up - 1][1] + 1) % P;
                }
                if (deg) {
                    int kk = 0, prod = 1;
                    int lst = 0;
                    sz = siz[to[1]];
                    for (int i = 1; i <= deg; i++) {
                        int v = to[i];
                        if (siz[v] != sz) {
                            lst = i;
                            break;
                        }
                        kk = (prod + 1ll * kk * f[v][1][sz]) % P;
                        prod = 1ll * prod * f[v][1][sz] % P;
                    }
                    if (lst) kk = 1ll * kk * fsuf[lst][sz] % P;
                    if (s.count(sz)) f[u][1][sz] = (f[u][1][sz] + kk) % P;
                    if (u == 1 && s.count(sz)) {
                        ans[sz][1] = (ans[sz][1] + kk) % P;
                    }
                }
            }
        } else {
            int vv = 0;
            for (int v : e[u]) if (v != pre) {
                vv = v;
                break;
            }
            if (!vv) {
                for (int t : s) f[u][0][t] = f[u][1][t] = 1;
            } else {
                for (int k = 0; k <= 1; k++) {
                    for (auto p : f[vv][k]) {
                        int t = p.first;
                        f[u][k][t] = 1;
                        for (int v : e[u]) if (v != pre) {
                            f[u][k][t] = 1ll * f[u][k][t] * f[v][k][t] % P;
                        }
                    }
                }
            }
        }
    }
} t;
struct Graph {
    vector<int> e[MAXN];
    void add(int u, int v) {
        e[u].push_back(v), e[v].push_back(u);
    }
    int low[MAXN], dfn[MAXN], dcnt;
    stack<int> s;
    void tarjan(int u) {
        low[u] = dfn[u] = ++dcnt;
        s.push(u);
        for (int v : e[u]) if (!dfn[v]) {
            tarjan(v);
            low[u] = min(low[u], low[v]);
            if (low[v] >= dfn[u]) {
                int id = ++t.tot;
                while (1) {
                    int w = s.top(); s.pop();
                    t.add(id, w);
                    if (w == v) break;
                }
                t.add(id, u);
            }
        } else {
            low[u] = min(low[u], dfn[v]);
        }
    }
} g;
int main() {
    // freopen("cities.in", "r", stdin);
    // freopen("cities.out", "w", stdout);
    scanf("%d%d%d", &n, &m, &k);
    for (int i = 1; i <= m; i++) {
        int u, v; scanf("%d%d", &u, &v);
        g.add(u, v);
    }
    t.tot = n;
    g.tarjan(1);
    for (int i = 2; i <= n; i++) 
        s.insert(n / i);
    t.dfs(1, 0);
    int fans = 0;
    if (k == 0) {
        for (int i : s) if (n % i == 0) {
            fans = (fans + ans[i][0]) % P;
        }
    } else {
        for (int i : s) {
            fans = (fans + ans[i][1]) % P;
        }
        for (int i : s) if (s.count(i - 1)) {
            fans = (fans - ans[i][0] + P) % P;
        }
    }
    printf("%d\n", fans);
    return 0;
}
posted @ 2023-04-07 17:01  APJifengc  阅读(219)  评论(0编辑  收藏  举报