「NOI2020」命运(树形DP+线段树合并)
Address
Solution
这种题,这种数据范围,应该很容易去想树形 DP。
树形 DP 最常见的套路就是合并子树,那就考虑,一个子树中边的权值对哪些路径有影响。例如一个点 \(u\) 的子树,里面的边对两种路径 \((x,y)\) 有影响。
一种是 \(x,y\) 都在 \(u\) 子树内,显然 \(u\) 子树内的边权就决定了这些路径能不能被满足。(每条路径上是否存在至少一条边为 \(1\))
另一种是 \(x\) 在 \(u\) 子树外,\(y\) 在 \(u\) 子树内。对于这类路径,要么路径 \((u,y)\) 上至少有一个 \(1\),要么 \((x,u)\) 上至少有一个 \(1\)。
也就是说,如果确定了 \(u\) 子树内每条边的权值,那么未被满足的路径只可能是第二种。注意到这种情况下,\(x\) 一定是 \(u\) 的祖先。那么接下来,我们需要确定 \(u\) 到根的路径上每条边的权值,使得这些路径被满足。
我们记这些还未被满足的路径中,深度最小的 \(x\) 为 \(x_0\),那么我们必须保证路径 \((x_0,u)\) 上至少有一个 \(1\)。
换句话说,我们让 \((x_0,u)\) 中,最浅的那条边为 \(1\),其余边为 \(0\),就能使得这些路径都被满足。
于是,记 \(f_{u,i}\) 表示确定 \(u\) 子树中每条边的权值,使得当 \((anc_i,anc_{i+1})\) 边为 \(1\),\(anc_{i+1}→u\) 的边都为 \(0\) 时,能满足所有和 \(u\) 子树相交的路径,有多少种方案。其中 \(anc_i\) 表示 \(u\) 的祖先中,深度为 \(i\) 的那一个。
所以 \(O(n^2)\) DP 就能写出来了:
inline void dfs2(int u, int pa)
{
int i, j;
for (i = a[u]; i < dep[u]; i++) f[u][i] = 1;
for (j = adj[u]; j; j = nxt[j])
{
int v = go[j];
if (v == pa) continue;
dfs2(v, u);
for (i = a[u]; i < dep[u]; i++)
f[u][i] = (ll)f[u][i] * (f[v][dep[u]] + f[v][i]) % mod;
}
}
其中 \(a_u\) 是以 \(u\) 为终点的路径中,起点的最小深度,没有则为 \(0\),点的深度从 \(1\) 开始。
根据 \(f\) 的定义,\(f_{u,i}\) 的 \(i\) 必须小于 \(dep_u\)。
转移则枚举 \(i\),考虑 \(f_{v,*}\) 对 \(f_{u,i}\) 的贡献。考虑 \((u,v)\) 是 \(0\) 还是 \(1\),如果是 \(1\),方案数就是 \(f_{v,dep_u}\),否则方案数是 \(f_{v,i}\)。
注意若 \(i<a_u\),则 \(f_{u,i}=0\)。为什么?因为这时候,所有以 \(u\) 为终点的路径上都还没被满足,如果你把 \(1\) 放到最深的起点的上方,那这条路径就不合法了。否则,这些以 \(u\) 为终点的路径就都会被满足。
接下来考虑怎么线段树合并优化:
根据上述代码,我们需要支持 \(3\) 种操作:
for (i = 0; i <= L; i++) f[u][i] += v;
for (i = 0; i <= L; i++) if (i < a[u] || i >= dep[u]) f[u][i] = 0;
for (i = 0; i <= L; i++) f[u][i] = f[u][i] * f[v][i];
也就是全局加,区间清零,对应位置相乘。查询的话只要单点查询。
有加有乘,考虑对线段树上的每个节点 \(u\) 维护两个标记 \(add_u,mul_u\)。其中 \(add_u\) 表示区间中的每个元素都加上 \(add_u\),\(mul_u\) 表示把 \(u\) 子树中的每个 \(add_x\) 都乘上 \(mul_u\)(除了 \(add_u\))。
单点查询就是对于线段树根到叶子的一条路径,把 \(add_x\) 乘上 \(\lceil\) \(x\) 的祖先的 \(mul\) 之积 \(\rfloor\) 的值全部加起来即可。
inline int ask(int u, int l, int r, int s, int tmp)
{
if (!u) return 0;
int res = (ll)c[u].add * tmp % mod;
if (l == r) return res;
int mid = l + r >> 1;
if (s <= mid) return S(res, ask(c[u].l, l, mid, s, M(tmp, c[u].mul)));
else return S(res, ask(c[u].r, mid + 1, r, s, M(tmp, c[u].mul)));
}
区间清零:若 \(u\) 节点不是递归边界,则把 \(u\) 的 \(add,mul\) 标记下放(注意必须同时下传左右子树,如果其中一个子树为空,则新建节点),否则直接删除子树 \(u\)。
inline void pushdown(int u)
{
if (!c[u].add && c[u].mul == 1) return;
int &l = c[u].l, &r = c[u].r;
if (!l) l = getnode();
if (!r) r = getnode();
c[l].add = ((ll)c[l].add * c[u].mul + c[u].add) % mod;
c[r].add = ((ll)c[r].add * c[u].mul + c[u].add) % mod;
c[l].mul = (ll)c[l].mul * c[u].mul % mod;
c[r].mul = (ll)c[r].mul * c[u].mul % mod;
c[u].add = 0;
c[u].mul = 1;
}
inline void cover(int &u, int l, int r, int s, int t)
{
if (l == s && r == t)
{
u = 0;
c[u].add = c[u].l = c[u].r = 0;
c[u].mul = 1;
return;
}
if (!u) return;
pushdown(u);
int mid = l + r >> 1;
if (t <= mid) cover(c[u].l, l, mid, s, t);
else if (s > mid) cover(c[u].r, mid + 1, r, s, t);
else
{
cover(c[u].l, l, mid, s, mid);
cover(c[u].r, mid + 1, r, mid + 1, t);
}
}
对应位置相乘:举个例子:两棵线段树,线段树 1 有一条路径 \(x_1,x_2,x_3,x_4,x_5\)(按从祖先到后代顺序),线段树 2 对应的路径是 \(y_1,y_2\),\(y_3\) 往下为空节点。也就是说合并它们的时候,会先访问 \(x_1,y_1\) 和 \(x_2,y_2\),当访问到 \(x_3\) 时,发现 \(y_3\) 是空,就 return
了。
设 \(x_1\sim x_5\) 的 \(add\) 分别为 \(a_1\sim a_5\),设 \(y_1\sim y_5\) 的 \(add\) 分别为 \(b_1\sim b_5\)。假设 \(mul\) 全部都是 \(1\),根据乘法分配律,设合并之后的 \(add\) 分别为 \(c_1\sim c_5\),则 \(c_1=a_1b_1,c_2=a_2b_1+a_1b_2+a_2b_2\)。以此类推,若 \(x_i,y_i\) 同时存在,则 \(c_i=a_ib_i+a_iB_{i-1}+A_{i-1}b_i\),其中 \(A,B\) 是 \(a_i,b_i\) 的前缀和,可以在递归的时候顺便记一下。
再把 \(mul\) 考虑进去,把 \(\forall i,a_i,b_i\) 都乘上根节点到它父亲的 \(mul\) 之积后,再参与 \(A,B,c\) 的计算即可。算完之后的 \(add\) 就是真的 \(add\) 了,也就是说算完之后要把这些节点的 \(mul\) 都还原为 \(1\)。
最后还有 \(x_3,y_3\),因为 \(y_3\) 是空节点,所以必须在 \(x_3\) 上做些标记。如果暴力的话,我们需要把 \(x_3,x_4,x_5\) 的 \(add\) 都乘上 \(B_2\)。但为了保证时间复杂度显然不能这么干,因此你只能把 \(x_3\) 的 \(add\) 乘上 \(B_2\),然后把 \(x_3\) 的 \(mul\) 乘上 \(B_2\)。
因为我们算完之后要把 \(x_2,y_2\) 往上的 \(mul\) 全部还原为 \(1\),而这些 \(mul\) 在还原之前的值对 \(x_4,x_5\) 是有影响的,所以记下这个影响 \(prod\),再把 \(x_3\) 的 \(mul\) 乘上 \(prod\) 就行了。(\(prod\) 就是根到 \(x_2\) 的 \(mul\) 之积)
inline int merge(int u, int v, int suma, int sumb, int l, int r, int mula, int mulb)
{
int a1 = (ll)c[u].add * mula % mod, b1 = (ll)c[v].add * mulb % mod;
if (!u || !v)
{
int x = u ^ v;
if (u)
{
c[x].add = (ll)a1 * sumb % mod;
c[x].mul = (ll)c[u].mul * sumb % mod * mula % mod;
}
else
{
c[x].add = (ll)b1 * suma % mod;
c[x].mul = (ll)c[v].mul * suma % mod * mulb % mod;
}
return x;
}
int mid = l + r >> 1;
c[u].add = (ll)a1 * b1 % mod;
plu(c[u].add, (ll)a1 * sumb % mod);
plu(c[u].add, (ll)b1 * suma % mod);
int ta = S(suma, a1), tb = S(sumb, b1),
pa = M(mula, c[u].mul), pb = M(mulb, c[v].mul);
c[u].l = merge(c[u].l, c[v].l, ta, tb, l, mid, pa, pb);
c[u].r = merge(c[u].r, c[v].r, ta, tb, mid + 1, r, pa, pb);
c[u].mul = 1;
return u;
}
时空复杂度 \(O(n\log n)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define ll long long
template <class t>
inline void read(t & res)
{
char ch;
while (ch = getchar(), !isdigit(ch));
res = ch ^ 48;
while (ch = getchar(), isdigit(ch))
res = res * 10 + (ch ^ 48);
}
const int N = 1e6 + 15, mod = 998244353, Q = 2e7 + 15;
struct point
{
int add, l, r, mul;
}c[Q];
int dep[N], L, n, m, adj[N], nxt[N], go[N], num, a[N], rt[N], cnt;
inline int getnode()
{
c[++cnt].mul = 1;
return cnt;
}
inline void plu(int &x, int y)
{
(x += y) >= mod && (x -= mod);
}
inline int S(int x, int y)
{
plu(x, y);
return x;
}
inline int M(int x, int y)
{
return (ll)x * y % mod;
}
inline void link(int x, int y)
{
nxt[++num] = adj[x]; adj[x] = num; go[num] = y;
nxt[++num] = adj[y]; adj[y] = num; go[num] = x;
}
inline void dfs1(int u, int pa)
{
dep[u] = dep[pa] + 1;
L = max(L, dep[u]);
for (int i = adj[u]; i; i = nxt[i])
{
int v = go[i];
if (v == pa) continue;
dfs1(v, u);
}
}
inline void pushdown(int u)
{
if (!c[u].add && c[u].mul == 1) return;
int &l = c[u].l, &r = c[u].r;
if (!l) l = getnode();
if (!r) r = getnode();
c[l].add = ((ll)c[l].add * c[u].mul + c[u].add) % mod;
c[r].add = ((ll)c[r].add * c[u].mul + c[u].add) % mod;
c[l].mul = (ll)c[l].mul * c[u].mul % mod;
c[r].mul = (ll)c[r].mul * c[u].mul % mod;
c[u].add = 0;
c[u].mul = 1;
}
inline void cover(int &u, int l, int r, int s, int t)
{
if (l == s && r == t)
{
u = 0;
c[u].add = c[u].l = c[u].r = 0;
c[u].mul = 1;
return;
}
if (!u) return;
pushdown(u);
int mid = l + r >> 1;
if (t <= mid) cover(c[u].l, l, mid, s, t);
else if (s > mid) cover(c[u].r, mid + 1, r, s, t);
else
{
cover(c[u].l, l, mid, s, mid);
cover(c[u].r, mid + 1, r, mid + 1, t);
}
}
inline int ask(int u, int l, int r, int s, int tmp)
{
if (!u) return 0;
int res = (ll)c[u].add * tmp % mod;
if (l == r) return res;
int mid = l + r >> 1;
if (s <= mid) return S(res, ask(c[u].l, l, mid, s, M(tmp, c[u].mul)));
else return S(res, ask(c[u].r, mid + 1, r, s, M(tmp, c[u].mul)));
}
inline int calc(int u, int i)
{
return ask(rt[u], 0, L, i, 1);
}
inline int merge(int u, int v, int suma, int sumb, int l, int r, int mula, int mulb)
{
int a1 = (ll)c[u].add * mula % mod, b1 = (ll)c[v].add * mulb % mod;
if (!u || !v)
{
int x = u ^ v;
if (u)
{
c[x].add = (ll)a1 * sumb % mod;
c[x].mul = (ll)c[u].mul * sumb % mod * mula % mod;
}
else
{
c[x].add = (ll)b1 * suma % mod;
c[x].mul = (ll)c[v].mul * suma % mod * mulb % mod;
}
return x;
}
int mid = l + r >> 1;
c[u].add = (ll)a1 * b1 % mod;
plu(c[u].add, (ll)a1 * sumb % mod);
plu(c[u].add, (ll)b1 * suma % mod);
int ta = S(suma, a1), tb = S(sumb, b1),
pa = M(mula, c[u].mul), pb = M(mulb, c[v].mul);
c[u].l = merge(c[u].l, c[v].l, ta, tb, l, mid, pa, pb);
c[u].r = merge(c[u].r, c[v].r, ta, tb, mid + 1, r, pa, pb);
c[u].mul = 1;
return u;
}
inline void dfs2(int u, int pa)
{
int i, j;
for (j = adj[u]; j; j = nxt[j])
{
int v = go[j];
if (v == pa) continue;
dfs2(v, u);
}
rt[u] = getnode();
c[rt[u]].add = 1;
for (j = adj[u]; j; j = nxt[j])
{
int v = go[j];
if (v == pa) continue;
int gv = calc(v, dep[u]);
plu(c[rt[v]].add, gv);
rt[u] = merge(rt[u], rt[v], 0, 0, 0, L, 1, 1);
}
if (a[u] > 0) cover(rt[u], 0, L, 0, a[u] - 1);
if (dep[u] <= L) cover(rt[u], 0, L, dep[u], L);
}
int main()
{
freopen("destiny.in", "r", stdin);
freopen("destiny.out", "w", stdout);
read(n);
int i, x, y;
for (i = 1; i < n; i++) read(x), read(y), link(x, y);
dfs1(1, 0);
read(m);
while (m--)
{
read(x); read(y);
a[y] = max(a[y], dep[x]);
}
dfs2(1, 0);
cout << calc(1, 0) << endl;
fclose(stdin);
fclose(stdout);
return 0;
}