YC317C [ 20240708 CQYC省选模拟赛 T3 ] 划分树(partition_tree)

题意

给定一棵 \(n\) 个点的树,你需要对于给定的 \(0 \le i \le k\),计算有多少删边方案,满足删除的边数为 \(i\) 且各连通块编号连续。

\(n \le 3 \times 10 ^ 5, k \le 500\)

Sol

显然边数等价于连通块数。

不难思考一个 \(\text{dp}\),设 \(f_{i, j}\) 表示前 \(i\) 个点,组成了 \(j\) 个连通块。

简单转移一下,每次考虑一个 \(k \to i\) 的区间,若满足条件就从 \(f_{k - 1, j - 1} \to f_{i, j}\)

这个 \(\text{dp}\)\(O(n ^ 2 k)\) 的,十分炸裂。

考虑扫描线,每一次增加一个新节点 \(i\)

不难发现其实满足 \([k, i]\) 的节点编号等价于使得 \([k, i]\) 的导出子图边数等于 \(i - k\)

于是直接对于 \(i\) 的所有连边,从后往前在线段树上区间加。而求答案直接对于每个点加上 \(l\),然后求所有最小值之和就行了。

这样复杂度变为 \(O(nk \log n)\)

你发现先枚举 \(k\) 然后再扫描线太傻逼了,把顺序调一下。

现在变成对于每个点维护 \(k\) 个信息,然后找一个可以转移的区间合并一下信息就可以了?

可以转移的区间根本不连续,似乎没法做了。

考虑一个转化,设区间 \([l, r]\) 的不合法度为 \(\sum_{u = l} ^ r [fa_u < l] + [fa_u > r]\),显然,合法区间满足当且仅当不合法度为 \(1\)

注意到后面那个东西已经可以做了,随便用个队列搞一下,因为 \(r\) 是单调递增的,所以删掉的元素一定不会在后面有用,直接求出 \(i\) 前面第一个以及第二个 \(fa_j > i\) 的即可。

思考一下,发现对于一个 \(i\),前面这坨有贡献的位置只有当 \(fa_i < i\) 时,\((fa_i, i]\) 中的下标会有不合法度为 \(1\) 的贡献,而且这个东西和当前右端点无关。

因此我们需要分别维护 \(\sum_{u = l} ^ r [fa_u < l]\)\(0\)\(1\) 的数据结构,需要支持后缀加 \(1\),其实就是删除并向更高一级的数据结构插入就行了。

这样所有的操作都在末尾进行,可以直接维护前缀和。

查询需要注意,对于 \(\sum_{u = l} ^ r [fa_u > r]\)\(0 / 1\) 段在数据结构上查询,下标数组有序,直接二分查找左右端点即可。

Code

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#include <stack>
#include <vector>
#include <assert.h>
using namespace std;
#ifdef ONLINE_JUDGE

#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;

#endif
int read() {
    int p = 0, flg = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-') flg = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        p = p * 10 + c - '0';
        c = getchar();
    }
    return p * flg;
}
void write(int x) {
    if (x < 0) {
        x = -x;
        putchar('-');
    }
    if (x > 9) {
        write(x / 10);
    }
    putchar(x % 10 + '0');
}
bool _stmer;

int debug = 0;

const int N = 3e5 + 5, M = 305, mod = 998244353;

namespace G {

array <int, N> fir;
array <int, N * 2> nex, to;

int cnt = 1;
void add(int x, int y) {
    cnt++;
    nex[cnt] = fir[x];
    to[cnt] = y;
    fir[x] = cnt;
}

} //namespace G

array <int, N> fa;

void dfs(int x) {
    for (int i = G::fir[x]; i; i = G::nex[i]) {
        if (G::to[i] == fa[x]) continue;
        fa[G::to[i]] = x, dfs(G::to[i]);
    }
}

void inc(int &x, int y) { x += y; if (x >= mod) x -= mod; }

void _inc(int &x, int y) { x -= y; if (x < 0) x += mod; }

array <array <array <int, M>, N>, 2> isl;

vector <int> arc0, arc1;

void add(array <int, M> &tp, int pl) {
    int n = pl ? arc1.size() : arc0.size();

    for (int i = 0; i <= 301; i++) {
        isl[pl][n][i] = tp[i];
        inc(isl[pl][n][i], isl[pl][n - 1][i]);
        assert(n - 1 >= 0);
    }
}

void del(int x, int pl) {
    if (!pl) {
        for (int i = 0; i <= 301; i++)
            _inc(isl[pl][x][i], isl[pl][x - 1][i]);
        add(isl[pl][x], pl ^ 1);
        for (int i = 0; i <= 301; i++)
            inc(isl[pl][x][i], isl[pl][x - 1][i]);
    }
}

void query0(int l, int r, array <int, M> &tp) {
    if (l > r || !arc0.size()) return;
    if (r > arc0.back()) r = arc0.back();
    l = lower_bound(arc0.begin(), arc0.end(), l) - arc0.begin();
    r = upper_bound(arc0.begin(), arc0.end(), r) - arc0.begin();
    for (int i = 0; i < 301; i++) inc(tp[i + 1], isl[0][r][i]);
    for (int i = 0; i < 301; i++) _inc(tp[i + 1], isl[0][l][i]);
}

void query1(int l, int r, array <int, M> &tp) {
    if (l > r || !arc1.size()) return;
    if (r > arc1.back()) r = arc1.back();
    l = lower_bound(arc1.begin(), arc1.end(), l) - arc1.begin();
    r = upper_bound(arc1.begin(), arc1.end(), r) - arc1.begin();
    for (int i = 0; i < 301; i++) inc(tp[i + 1], isl[1][r][i]);
    for (int i = 0; i < 301; i++) _inc(tp[i + 1], isl[1][l][i]);
}

stack <int> stk;

bool _edmer;
signed main() {
    cerr << (&_stmer - &_edmer) / 1024.0 / 1024.0 << "MB\n";
    int n = read(), k = read();
    for (int i = 2, x, y; i <= n; i++)
        x = read(), y = read(), G::add(x, y), G::add(y, x);
    dfs(1);
    array <int, M> lst;
    lst.fill(0), lst[0] = 1;
    for (int i = 1; i <= n; i++) {
        int tp1 = 0, tp2 = 0;
        stk.push(i);
        while (stk.size() && fa[stk.top()] <= i) stk.pop();
        if (stk.size()) tp1 = stk.top(), stk.pop();
        while (stk.size() && fa[stk.top()] <= i) stk.pop();
        if (stk.size()) tp2 = stk.top(), stk.pop();
        arc0.push_back(i), add(lst, 0);


        if (fa[i] < i) {
            int res = arc1.size();
            while (res > 0 && arc1[res - 1] > fa[i]) res--;
            for (int j = res; j < (int)arc1.size(); j++) del(j + 1, 1);
            while ((int)arc1.size() > res) arc1.pop_back();

            res = arc0.size();
            while (res > 0 && arc0[res - 1] > fa[i]) res--;
            for (int j = res; j < (int)arc0.size(); j++)
                arc1.push_back(arc0[j]), del(j + 1, 0);
            while ((int)arc0.size() > res) arc0.pop_back();
        }

        if (tp2) stk.push(tp2);
        if (tp1) stk.push(tp1);

        lst.fill(0);
        if (tp1) query0(tp2 + 1, tp1, lst);
        query1(tp1 + 1, i, lst);

    }
    for (int i = 1; i <= k + 1; i++)
        write(lst[i]), puts("");
    return 0;
}
posted @ 2024-07-21 11:07  cxqghzj  阅读(13)  评论(0编辑  收藏  举报