P6773 [NOI2020] 命运
P6773 [NOI2020] 命运
考虑树形 DP
,套路是维护子树方案数。
定义 中 是 的祖先。
注意到存在性质:对于满足 的限制 ,满足第二个限制即可全部满足,即对于下端点 的所有限制,满足上端点最深的限制即可全部满足。
记 表示 的子树内,两端点在子树内的限制已满足,未被满足的下端点在子树内并且上端点最深为 的方案数,答案就是 。
考虑转移,将儿子子树往父亲子树合并,讨论边的取值。
套路记录前缀和 。
相乘再相加,系数会变,考虑对每个点维护一棵深度线段树,状态较少,转移时使用线段树合并即可。
时间复杂度 。
#include <cstdio>
#include <vector>
typedef long long ll;
#define ha putchar(' ')
#define he putchar('\n')
inline int read() {
int x = 0, f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-')
f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
return x * f;
}
inline void write(int x) {
if (x < 0) {
putchar('-');
x = -x;
}
if (x > 9)
write(x / 10);
putchar(x % 10 + 48);
}
const int _ = 5e5 + 10, mod = 998244353;
int n, m, cnt, dep[_], ls[_ << 5], rs[_ << 5], rt[_];
std::vector<int> d[_], lim[_];
ll f[_ << 5], tag[_ << 5];
void pushup(int o) {
f[o] = f[ls[o]] + f[rs[o]];
if (f[o] >= mod) f[o] %= mod;
}
void pushdown(int o) {
if (tag[o] != 1) {
f[ls[o]] *= tag[o];
if (f[ls[o]] >= mod) f[ls[o]] %= mod;
f[rs[o]] *= tag[o];
if (f[rs[o]] >= mod) f[rs[o]] %= mod;
tag[ls[o]] *= tag[o];
if (tag[ls[o]] >= mod) tag[ls[o]] %= mod;
tag[rs[o]] *= tag[o];
if (tag[rs[o]] >= mod) tag[rs[o]] %= mod;
tag[o] = 1;
}
}
void upd(int &o, int l, int r, int pos, ll v) {
!o ? tag[o = ++cnt] = 1 : 1;
if (l == r) return tag[o] = 1, f[o] = v, void();
pushdown(o);
int mid = (l + r) >> 1;
pos <= mid ? upd(ls[o], l, mid, pos, v) : upd(rs[o], mid + 1, r, pos, v);
pushup(o);
}
ll qry(int o, int l, int r, int L, int R) {
if (L <= l && r <= R) return f[o];
pushdown(o);
int mid = (l + r) >> 1;
ll res = 0;
if (L <= mid) res = qry(ls[o], l, mid, L, R);
if (R > mid) res += qry(rs[o], mid + 1, r, L, R);
if (res >= mod) res %= mod;
return res;
}
int mge(int x, int y, int l, int r, int &k1, int &k2) {
if (!x && !y) return 0;
if (!x) {
k2 += f[y], tag[y] *= k1, f[y] *= k1;
if (k2 >= mod) k2 %= mod;
if (tag[y] >= mod) tag[y] %= mod;
if (f[y] >= mod) f[y] %= mod;
return y;
}
if (!y) {
k1 += f[x], tag[x] *= k2, f[x] *= k2;
if (k1 >= mod) k1 %= mod;
if (tag[x] >= mod) tag[x] %= mod;
if (f[x] >= mod) f[x] %= mod;
return x;
}
if (l == r) {
ll fx = f[x];
k2 += f[y], f[x] = (f[x] * k2 + f[y] * k1), k1 += fx;
if (k2 >= mod) k2 %= mod;
if (k1 >= mod) k1 %= mod;
if (f[x] >= mod) f[x] %= mod;
return x;
}
pushdown(x), pushdown(y);
int mid = (l + r) >> 1;
ls[x] = mge(ls[x], ls[y], l, mid, k1, k2);
rs[x] = mge(rs[x], rs[y], mid + 1, r, k1, k2);
pushup(x);
return x;
}
void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
int mx = 0, k1, k2;
for (int v : lim[u]) mx = std::max(mx, dep[v]);
upd(rt[u], 0, n, mx, 1);
for (int v : d[u])
if (v != fa) {
dfs(v, u);
k1 = 0, k2 = qry(rt[v], 0, n, 0, dep[u]);
rt[u] = mge(rt[u], rt[v], 0, n, k1, k2);
}
}
signed main() {
int u, v;
n = read();
for (int i = 1; i < n; ++i) {
u = read(), v = read();
d[u].emplace_back(v), d[v].emplace_back(u);
}
m = read();
for (int i = 1; i <= m; ++i) {
u = read(), v = read();
lim[v].emplace_back(u);
}
dfs(1, 0);
write(qry(rt[1], 0, n, 0, 0)), he;
return 0;
}
本文来自博客园,作者:蒟蒻orz,转载请注明原文链接:https://www.cnblogs.com/orzz/p/18121961