YC317C [ 20240708 CQYC省选模拟赛 T3 ] 划分树(partition_tree)
题意
给定一棵 \(n\) 个点的树,你需要对于给定的 \(0 \le i \le k\),计算有多少删边方案,满足删除的边数为 \(i\) 且各连通块编号连续。
\(n \le 3 \times 10 ^ 5, k \le 500\)。
Sol
显然边数等价于连通块数。
不难思考一个 \(\text{dp}\),设 \(f_{i, j}\) 表示前 \(i\) 个点,组成了 \(j\) 个连通块。
简单转移一下,每次考虑一个 \(k \to i\) 的区间,若满足条件就从 \(f_{k - 1, j - 1} \to f_{i, j}\)。
这个 \(\text{dp}\) 是 \(O(n ^ 2 k)\) 的,十分炸裂。
考虑扫描线,每一次增加一个新节点 \(i\)。
不难发现其实满足 \([k, i]\) 的节点编号等价于使得 \([k, i]\) 的导出子图边数等于 \(i - k\)。
于是直接对于 \(i\) 的所有连边,从后往前在线段树上区间加。而求答案直接对于每个点加上 \(l\),然后求所有最小值之和就行了。
这样复杂度变为 \(O(nk \log n)\)。
你发现先枚举 \(k\) 然后再扫描线太傻逼了,把顺序调一下。
现在变成对于每个点维护 \(k\) 个信息,然后找一个可以转移的区间合并一下信息就可以了?
可以转移的区间根本不连续,似乎没法做了。
考虑一个转化,设区间 \([l, r]\) 的不合法度为 \(\sum_{u = l} ^ r [fa_u < l] + [fa_u > r]\),显然,合法区间满足当且仅当不合法度为 \(1\)。
注意到后面那个东西已经可以做了,随便用个队列搞一下,因为 \(r\) 是单调递增的,所以删掉的元素一定不会在后面有用,直接求出 \(i\) 前面第一个以及第二个 \(fa_j > i\) 的即可。
思考一下,发现对于一个 \(i\),前面这坨有贡献的位置只有当 \(fa_i < i\) 时,\((fa_i, i]\) 中的下标会有不合法度为 \(1\) 的贡献,而且这个东西和当前右端点无关。
因此我们需要分别维护 \(\sum_{u = l} ^ r [fa_u < l]\) 为 \(0\) 和 \(1\) 的数据结构,需要支持后缀加 \(1\),其实就是删除并向更高一级的数据结构插入就行了。
这样所有的操作都在末尾进行,可以直接维护前缀和。
查询需要注意,对于 \(\sum_{u = l} ^ r [fa_u > r]\) 的 \(0 / 1\) 段在数据结构上查询,下标数组有序,直接二分查找左右端点即可。
Code
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#include <stack>
#include <vector>
#include <assert.h>
using namespace std;
#ifdef ONLINE_JUDGE
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;
#endif
int read() {
int p = 0, flg = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') flg = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
p = p * 10 + c - '0';
c = getchar();
}
return p * flg;
}
void write(int x) {
if (x < 0) {
x = -x;
putchar('-');
}
if (x > 9) {
write(x / 10);
}
putchar(x % 10 + '0');
}
bool _stmer;
int debug = 0;
const int N = 3e5 + 5, M = 305, mod = 998244353;
namespace G {
array <int, N> fir;
array <int, N * 2> nex, to;
int cnt = 1;
void add(int x, int y) {
cnt++;
nex[cnt] = fir[x];
to[cnt] = y;
fir[x] = cnt;
}
} //namespace G
array <int, N> fa;
void dfs(int x) {
for (int i = G::fir[x]; i; i = G::nex[i]) {
if (G::to[i] == fa[x]) continue;
fa[G::to[i]] = x, dfs(G::to[i]);
}
}
void inc(int &x, int y) { x += y; if (x >= mod) x -= mod; }
void _inc(int &x, int y) { x -= y; if (x < 0) x += mod; }
array <array <array <int, M>, N>, 2> isl;
vector <int> arc0, arc1;
void add(array <int, M> &tp, int pl) {
int n = pl ? arc1.size() : arc0.size();
for (int i = 0; i <= 301; i++) {
isl[pl][n][i] = tp[i];
inc(isl[pl][n][i], isl[pl][n - 1][i]);
assert(n - 1 >= 0);
}
}
void del(int x, int pl) {
if (!pl) {
for (int i = 0; i <= 301; i++)
_inc(isl[pl][x][i], isl[pl][x - 1][i]);
add(isl[pl][x], pl ^ 1);
for (int i = 0; i <= 301; i++)
inc(isl[pl][x][i], isl[pl][x - 1][i]);
}
}
void query0(int l, int r, array <int, M> &tp) {
if (l > r || !arc0.size()) return;
if (r > arc0.back()) r = arc0.back();
l = lower_bound(arc0.begin(), arc0.end(), l) - arc0.begin();
r = upper_bound(arc0.begin(), arc0.end(), r) - arc0.begin();
for (int i = 0; i < 301; i++) inc(tp[i + 1], isl[0][r][i]);
for (int i = 0; i < 301; i++) _inc(tp[i + 1], isl[0][l][i]);
}
void query1(int l, int r, array <int, M> &tp) {
if (l > r || !arc1.size()) return;
if (r > arc1.back()) r = arc1.back();
l = lower_bound(arc1.begin(), arc1.end(), l) - arc1.begin();
r = upper_bound(arc1.begin(), arc1.end(), r) - arc1.begin();
for (int i = 0; i < 301; i++) inc(tp[i + 1], isl[1][r][i]);
for (int i = 0; i < 301; i++) _inc(tp[i + 1], isl[1][l][i]);
}
stack <int> stk;
bool _edmer;
signed main() {
cerr << (&_stmer - &_edmer) / 1024.0 / 1024.0 << "MB\n";
int n = read(), k = read();
for (int i = 2, x, y; i <= n; i++)
x = read(), y = read(), G::add(x, y), G::add(y, x);
dfs(1);
array <int, M> lst;
lst.fill(0), lst[0] = 1;
for (int i = 1; i <= n; i++) {
int tp1 = 0, tp2 = 0;
stk.push(i);
while (stk.size() && fa[stk.top()] <= i) stk.pop();
if (stk.size()) tp1 = stk.top(), stk.pop();
while (stk.size() && fa[stk.top()] <= i) stk.pop();
if (stk.size()) tp2 = stk.top(), stk.pop();
arc0.push_back(i), add(lst, 0);
if (fa[i] < i) {
int res = arc1.size();
while (res > 0 && arc1[res - 1] > fa[i]) res--;
for (int j = res; j < (int)arc1.size(); j++) del(j + 1, 1);
while ((int)arc1.size() > res) arc1.pop_back();
res = arc0.size();
while (res > 0 && arc0[res - 1] > fa[i]) res--;
for (int j = res; j < (int)arc0.size(); j++)
arc1.push_back(arc0[j]), del(j + 1, 0);
while ((int)arc0.size() > res) arc0.pop_back();
}
if (tp2) stk.push(tp2);
if (tp1) stk.push(tp1);
lst.fill(0);
if (tp1) query0(tp2 + 1, tp1, lst);
query1(tp1 + 1, i, lst);
}
for (int i = 1; i <= k + 1; i++)
write(lst[i]), puts("");
return 0;
}