P6773 [NOI2020] 命运

P6773 [NOI2020] 命运

考虑树形 DP,套路是维护子树方案数。

定义 (u,v)(u,v)vvuu 的祖先。

注意到存在性质:对于满足 depv1<depv2dep_{v1} < dep_{v2} 的限制 (u,v1),(u,v2)(u,v1),(u,v2),满足第二个限制即可全部满足,即对于下端点 uu 的所有限制,满足上端点最深的限制即可全部满足。

fu,if_{u,i} 表示 uu 的子树内,两端点在子树内的限制已满足,未被满足的下端点在子树内并且上端点最深为 ii 的方案数,答案就是 f1,0f_{1,0}

考虑转移,将儿子子树往父亲子树合并,讨论边的取值。

fu,ij=0depufu,ifv,j+j=0ifu,ifv,j+j=0i1fu,jfv,if_{u,i}\leftarrow \sum_{j=0}^{dep_u}f_{u,i}f_{v,j}+\sum_{j=0}^{i}f_{u,i}f_{v,j}+\sum_{j=0}^{i-1}f_{u,j}f_{v,i}

套路记录前缀和 gu,i=j=0ifu,ig_{u,i}=\sum_{j=0}^{i}f_{u,i}

fu,ifu,i(gv,depu+gv,i)+gu,i1fv,if_{u,i}\leftarrow f_{u,i}(g_{v,dep_u}+g_{v,i})+g_{u,i-1}f_{v,i}

相乘再相加,系数会变,考虑对每个点维护一棵深度线段树,状态较少,转移时使用线段树合并即可。

时间复杂度 O(nlogn)\mathcal O(n\log n)

#include <cstdio>
#include <vector>
typedef long long ll;

#define ha putchar(' ')
#define he putchar('\n')

inline int read() {
	int x = 0, f = 1;
	char c = getchar();
	while (c < '0' || c > '9') {
		if (c == '-')
			f = -1;
		c = getchar();
	}
	while (c >= '0' && c <= '9')
		x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
	return x * f;
}

inline void write(int x) {
	if (x < 0) {
		putchar('-');
		x = -x;
	}
	if (x > 9)
		write(x / 10);
	putchar(x % 10 + 48);
}

const int _ = 5e5 + 10, mod = 998244353;

int n, m, cnt, dep[_], ls[_ << 5], rs[_ << 5], rt[_];

std::vector<int> d[_], lim[_];

ll f[_ << 5], tag[_ << 5];

void pushup(int o) {
	f[o] = f[ls[o]] + f[rs[o]];
	if (f[o] >= mod) f[o] %= mod;
}

void pushdown(int o) {
	if (tag[o] != 1) {
		f[ls[o]] *= tag[o];
		if (f[ls[o]] >= mod) f[ls[o]] %= mod;
		f[rs[o]] *= tag[o];
		if (f[rs[o]] >= mod) f[rs[o]] %= mod;
		tag[ls[o]] *= tag[o];
		if (tag[ls[o]] >= mod) tag[ls[o]] %= mod;
		tag[rs[o]] *= tag[o];
		if (tag[rs[o]] >= mod) tag[rs[o]] %= mod;
		tag[o] = 1;
	}
}

void upd(int &o, int l, int r, int pos, ll v) {
	!o ? tag[o = ++cnt] = 1 : 1;
	if (l == r) return tag[o] = 1, f[o] = v, void();
	pushdown(o);
	int mid = (l + r) >> 1;
	pos <= mid ? upd(ls[o], l, mid, pos, v) : upd(rs[o], mid + 1, r, pos, v);
	pushup(o);
}

ll qry(int o, int l, int r, int L, int R) {
	if (L <= l && r <= R) return f[o];
	pushdown(o);
	int mid = (l + r) >> 1;
	ll res = 0;
	if (L <= mid) res = qry(ls[o], l, mid, L, R);
	if (R > mid) res += qry(rs[o], mid + 1, r, L, R);
	if (res >= mod) res %= mod;
	return res;
}

int mge(int x, int y, int l, int r, int &k1, int &k2) {
	if (!x && !y) return 0;
	if (!x) {
		k2 += f[y], tag[y] *= k1, f[y] *= k1;
		if (k2 >= mod) k2 %= mod;
		if (tag[y] >= mod) tag[y] %= mod;
		if (f[y] >= mod) f[y] %= mod;
		return y;
	}
	if (!y) {
		k1 += f[x], tag[x] *= k2, f[x] *= k2;
		if (k1 >= mod) k1 %= mod;
		if (tag[x] >= mod) tag[x] %= mod;
		if (f[x] >= mod) f[x] %= mod;
		return x;
	}
	if (l == r) {
		ll fx = f[x];
		k2 += f[y], f[x] = (f[x] * k2 + f[y] * k1), k1 += fx;
		if (k2 >= mod) k2 %= mod;
		if (k1 >= mod) k1 %= mod;
		if (f[x] >= mod) f[x] %= mod;
		return x;
	}
	pushdown(x), pushdown(y);
	int mid = (l + r) >> 1;
	ls[x] = mge(ls[x], ls[y], l, mid, k1, k2);
	rs[x] = mge(rs[x], rs[y], mid + 1, r, k1, k2);
	pushup(x);
	return x;
}

void dfs(int u, int fa) {
	dep[u] = dep[fa] + 1;
	int mx = 0, k1, k2;
	for (int v : lim[u]) mx = std::max(mx, dep[v]);
	upd(rt[u], 0, n, mx, 1);
	for (int v : d[u])
		if (v != fa) {
			dfs(v, u);
			k1 = 0, k2 = qry(rt[v], 0, n, 0, dep[u]);
			rt[u] = mge(rt[u], rt[v], 0, n, k1, k2);
		}
}

signed main() {
	int u, v;
	n = read();
	for (int i = 1; i < n; ++i) {
		u = read(), v = read();
		d[u].emplace_back(v), d[v].emplace_back(u);
	}
	m = read();
	for (int i = 1; i <= m; ++i) {
		u = read(), v = read();
		lim[v].emplace_back(u);
	}
	dfs(1, 0);
	write(qry(rt[1], 0, n, 0, 0)), he;
	return 0;
}
posted @ 2022-08-21 17:23  蒟蒻orz  阅读(4)  评论(0编辑  收藏  举报  来源