「NOI2020」命运(树形DP+线段树合并)

Address

LOJ3340

Solution

这种题,这种数据范围,应该很容易去想树形 DP。

树形 DP 最常见的套路就是合并子树,那就考虑,一个子树中边的权值对哪些路径有影响。例如一个点 \(u\) 的子树,里面的边对两种路径 \((x,y)\) 有影响。

一种是 \(x,y\) 都在 \(u\) 子树内,显然 \(u\) 子树内的边权就决定了这些路径能不能被满足。(每条路径上是否存在至少一条边为 \(1\)

另一种是 \(x\)\(u\) 子树外,\(y\)\(u\) 子树内。对于这类路径,要么路径 \((u,y)\) 上至少有一个 \(1\),要么 \((x,u)\) 上至少有一个 \(1\)

也就是说,如果确定了 \(u\) 子树内每条边的权值,那么未被满足的路径只可能是第二种。注意到这种情况下,\(x\) 一定是 \(u\) 的祖先。那么接下来,我们需要确定 \(u\) 到根的路径上每条边的权值,使得这些路径被满足。

我们记这些还未被满足的路径中,深度最小的 \(x\)\(x_0\),那么我们必须保证路径 \((x_0,u)\) 上至少有一个 \(1\)

换句话说,我们让 \((x_0,u)\) 中,最浅的那条边为 \(1\),其余边为 \(0\),就能使得这些路径都被满足。

于是,记 \(f_{u,i}\) 表示确定 \(u\) 子树中每条边的权值,使得当 \((anc_i,anc_{i+1})\) 边为 \(1\)\(anc_{i+1}→u\) 的边都为 \(0\) 时,能满足所有和 \(u\) 子树相交的路径,有多少种方案。其中 \(anc_i\) 表示 \(u\) 的祖先中,深度为 \(i\) 的那一个。

所以 \(O(n^2)\) DP 就能写出来了:

inline void dfs2(int u, int pa)
{
	int i, j;
	for (i = a[u]; i < dep[u]; i++) f[u][i] = 1;
	for (j = adj[u]; j; j = nxt[j])
	{
		int v = go[j];
		if (v == pa) continue;
		dfs2(v, u);
		for (i = a[u]; i < dep[u]; i++)
			f[u][i] = (ll)f[u][i] * (f[v][dep[u]] + f[v][i]) % mod;
	}
}

其中 \(a_u\) 是以 \(u\) 为终点的路径中,起点的最小深度,没有则为 \(0\),点的深度从 \(1\) 开始。

根据 \(f\) 的定义,\(f_{u,i}\)\(i\) 必须小于 \(dep_u\)

转移则枚举 \(i\),考虑 \(f_{v,*}\)\(f_{u,i}\) 的贡献。考虑 \((u,v)\)\(0\) 还是 \(1\),如果是 \(1\),方案数就是 \(f_{v,dep_u}\),否则方案数是 \(f_{v,i}\)

注意若 \(i<a_u\),则 \(f_{u,i}=0\)。为什么?因为这时候,所有以 \(u\) 为终点的路径上都还没被满足,如果你把 \(1\) 放到最深的起点的上方,那这条路径就不合法了。否则,这些以 \(u\) 为终点的路径就都会被满足。

接下来考虑怎么线段树合并优化:

根据上述代码,我们需要支持 \(3\) 种操作:

for (i = 0; i <= L; i++) f[u][i] += v;
for (i = 0; i <= L; i++) if (i < a[u] || i >= dep[u]) f[u][i] = 0;
for (i = 0; i <= L; i++) f[u][i] = f[u][i] * f[v][i];

也就是全局加,区间清零,对应位置相乘。查询的话只要单点查询。

有加有乘,考虑对线段树上的每个节点 \(u\) 维护两个标记 \(add_u,mul_u\)。其中 \(add_u\) 表示区间中的每个元素都加上 \(add_u\)\(mul_u\) 表示把 \(u\) 子树中的每个 \(add_x\) 都乘上 \(mul_u\)(除了 \(add_u\))。

单点查询就是对于线段树根到叶子的一条路径,把 \(add_x\) 乘上 \(\lceil\) \(x\) 的祖先的 \(mul\) 之积 \(\rfloor\) 的值全部加起来即可。

inline int ask(int u, int l, int r, int s, int tmp)
{
	if (!u) return 0;
	int res = (ll)c[u].add * tmp % mod;
	if (l == r) return res;
	int mid = l + r >> 1;
	if (s <= mid) return S(res, ask(c[u].l, l, mid, s, M(tmp, c[u].mul)));
	else return S(res, ask(c[u].r, mid + 1, r, s, M(tmp, c[u].mul)));
}

区间清零:若 \(u\) 节点不是递归边界,则把 \(u\)\(add,mul\) 标记下放(注意必须同时下传左右子树,如果其中一个子树为空,则新建节点),否则直接删除子树 \(u\)

inline void pushdown(int u)
{
	if (!c[u].add && c[u].mul == 1) return;
	int &l = c[u].l, &r = c[u].r;
	if (!l) l = getnode();
	if (!r) r = getnode();
	c[l].add = ((ll)c[l].add * c[u].mul + c[u].add) % mod;
	c[r].add = ((ll)c[r].add * c[u].mul + c[u].add) % mod;
	c[l].mul = (ll)c[l].mul * c[u].mul % mod;
	c[r].mul = (ll)c[r].mul * c[u].mul % mod;
	c[u].add = 0;
	c[u].mul = 1;
}

inline void cover(int &u, int l, int r, int s, int t)
{
	if (l == s && r == t)
	{
		u = 0;
		c[u].add = c[u].l = c[u].r = 0;
		c[u].mul = 1;
		return;
	}
	if (!u) return;
	pushdown(u);
	int mid = l + r >> 1;
	if (t <= mid) cover(c[u].l, l, mid, s, t);
	else if (s > mid) cover(c[u].r, mid + 1, r, s, t);
	else
	{
		cover(c[u].l, l, mid, s, mid);
		cover(c[u].r, mid + 1, r, mid + 1, t);
	}
}

对应位置相乘:举个例子:两棵线段树,线段树 1 有一条路径 \(x_1,x_2,x_3,x_4,x_5\)(按从祖先到后代顺序),线段树 2 对应的路径是 \(y_1,y_2\)\(y_3\) 往下为空节点。也就是说合并它们的时候,会先访问 \(x_1,y_1\)\(x_2,y_2\),当访问到 \(x_3\) 时,发现 \(y_3\) 是空,就 return 了。

\(x_1\sim x_5\)\(add\) 分别为 \(a_1\sim a_5\),设 \(y_1\sim y_5\)\(add\) 分别为 \(b_1\sim b_5\)。假设 \(mul\) 全部都是 \(1\),根据乘法分配律,设合并之后的 \(add\) 分别为 \(c_1\sim c_5\),则 \(c_1=a_1b_1,c_2=a_2b_1+a_1b_2+a_2b_2\)。以此类推,若 \(x_i,y_i\) 同时存在,则 \(c_i=a_ib_i+a_iB_{i-1}+A_{i-1}b_i\),其中 \(A,B\)\(a_i,b_i\) 的前缀和,可以在递归的时候顺便记一下。

再把 \(mul\) 考虑进去,把 \(\forall i,a_i,b_i\) 都乘上根节点到它父亲的 \(mul\) 之积后,再参与 \(A,B,c\) 的计算即可。算完之后的 \(add\) 就是真的 \(add\) 了,也就是说算完之后要把这些节点的 \(mul\) 都还原为 \(1\)

最后还有 \(x_3,y_3\),因为 \(y_3\) 是空节点,所以必须在 \(x_3\) 上做些标记。如果暴力的话,我们需要把 \(x_3,x_4,x_5\)\(add\) 都乘上 \(B_2\)。但为了保证时间复杂度显然不能这么干,因此你只能把 \(x_3\)\(add\) 乘上 \(B_2\),然后把 \(x_3\)\(mul\) 乘上 \(B_2\)

因为我们算完之后要把 \(x_2,y_2\) 往上的 \(mul\) 全部还原为 \(1\),而这些 \(mul\) 在还原之前的值对 \(x_4,x_5\) 是有影响的,所以记下这个影响 \(prod\),再把 \(x_3\)\(mul\) 乘上 \(prod\) 就行了。(\(prod\) 就是根到 \(x_2\)\(mul\) 之积)

inline int merge(int u, int v, int suma, int sumb, int l, int r, int mula, int mulb)
{
	int a1 = (ll)c[u].add * mula % mod, b1 = (ll)c[v].add * mulb % mod;
	if (!u || !v) 
	{
		int x = u ^ v;
		if (u)
		{
			c[x].add = (ll)a1 * sumb % mod;
			c[x].mul = (ll)c[u].mul * sumb % mod * mula % mod;
		}
		else 
		{
			c[x].add = (ll)b1 * suma % mod;
			c[x].mul = (ll)c[v].mul * suma % mod * mulb % mod;
		}
		return x;
	}
	int mid = l + r >> 1;
	c[u].add = (ll)a1 * b1 % mod;
	plu(c[u].add, (ll)a1 * sumb % mod);
	plu(c[u].add, (ll)b1 * suma % mod);
	int ta = S(suma, a1), tb = S(sumb, b1), 
		pa = M(mula, c[u].mul), pb = M(mulb, c[v].mul);
	c[u].l = merge(c[u].l, c[v].l, ta, tb, l, mid, pa, pb);
	c[u].r = merge(c[u].r, c[v].r, ta, tb, mid + 1, r, pa, pb);
	c[u].mul = 1;
	return u;
}

时空复杂度 \(O(n\log n)\)

Code

#include <bits/stdc++.h>

using namespace std;

#define ll long long

template <class t>
inline void read(t & res)
{
	char ch;
	while (ch = getchar(), !isdigit(ch));
	res = ch ^ 48;
	while (ch = getchar(), isdigit(ch))
		res = res * 10 + (ch ^ 48);
}

const int N = 1e6 + 15, mod = 998244353, Q = 2e7 + 15;

struct point
{
	int add, l, r, mul;
}c[Q];
int dep[N], L, n, m, adj[N], nxt[N], go[N], num, a[N], rt[N], cnt;

inline int getnode()
{
	c[++cnt].mul = 1;
	return cnt;
}

inline void plu(int &x, int y)
{
	(x += y) >= mod && (x -= mod);
}

inline int S(int x, int y)
{
	plu(x, y);
	return x;
}

inline int M(int x, int y)
{
	return (ll)x * y % mod;
}

inline void link(int x, int y)
{
	nxt[++num] = adj[x]; adj[x] = num; go[num] = y;
	nxt[++num] = adj[y]; adj[y] = num; go[num] = x;
}

inline void dfs1(int u, int pa)
{
	dep[u] = dep[pa] + 1;
	L = max(L, dep[u]);
	for (int i = adj[u]; i; i = nxt[i])
	{
		int v = go[i];
		if (v == pa) continue;
		dfs1(v, u); 
	}
}

inline void pushdown(int u)
{
	if (!c[u].add && c[u].mul == 1) return;
	int &l = c[u].l, &r = c[u].r;
	if (!l) l = getnode();
	if (!r) r = getnode();
	c[l].add = ((ll)c[l].add * c[u].mul + c[u].add) % mod;
	c[r].add = ((ll)c[r].add * c[u].mul + c[u].add) % mod;
	c[l].mul = (ll)c[l].mul * c[u].mul % mod;
	c[r].mul = (ll)c[r].mul * c[u].mul % mod;
	c[u].add = 0;
	c[u].mul = 1;
}

inline void cover(int &u, int l, int r, int s, int t)
{
	if (l == s && r == t)
	{
		u = 0;
		c[u].add = c[u].l = c[u].r = 0;
		c[u].mul = 1;
		return;
	}
	if (!u) return;
	pushdown(u);
	int mid = l + r >> 1;
	if (t <= mid) cover(c[u].l, l, mid, s, t);
	else if (s > mid) cover(c[u].r, mid + 1, r, s, t);
	else
	{
		cover(c[u].l, l, mid, s, mid);
		cover(c[u].r, mid + 1, r, mid + 1, t);
	}
}

inline int ask(int u, int l, int r, int s, int tmp)
{
	if (!u) return 0;
	int res = (ll)c[u].add * tmp % mod;
	if (l == r) return res;
	int mid = l + r >> 1;
	if (s <= mid) return S(res, ask(c[u].l, l, mid, s, M(tmp, c[u].mul)));
	else return S(res, ask(c[u].r, mid + 1, r, s, M(tmp, c[u].mul)));
}

inline int calc(int u, int i)
{
	return ask(rt[u], 0, L, i, 1);
}

inline int merge(int u, int v, int suma, int sumb, int l, int r, int mula, int mulb)
{
	int a1 = (ll)c[u].add * mula % mod, b1 = (ll)c[v].add * mulb % mod;
	if (!u || !v) 
	{
		int x = u ^ v;
		if (u)
		{
			c[x].add = (ll)a1 * sumb % mod;
			c[x].mul = (ll)c[u].mul * sumb % mod * mula % mod;
		}
		else 
		{
			c[x].add = (ll)b1 * suma % mod;
			c[x].mul = (ll)c[v].mul * suma % mod * mulb % mod;
		}
		return x;
	}
	int mid = l + r >> 1;
	c[u].add = (ll)a1 * b1 % mod;
	plu(c[u].add, (ll)a1 * sumb % mod);
	plu(c[u].add, (ll)b1 * suma % mod);
	int ta = S(suma, a1), tb = S(sumb, b1), 
		pa = M(mula, c[u].mul), pb = M(mulb, c[v].mul);
	c[u].l = merge(c[u].l, c[v].l, ta, tb, l, mid, pa, pb);
	c[u].r = merge(c[u].r, c[v].r, ta, tb, mid + 1, r, pa, pb);
	c[u].mul = 1;
	return u;
}

inline void dfs2(int u, int pa)
{
	int i, j;
	for (j = adj[u]; j; j = nxt[j])
	{
		int v = go[j];
		if (v == pa) continue;
		dfs2(v, u);
	}
	rt[u] = getnode();
	c[rt[u]].add = 1;
	for (j = adj[u]; j; j = nxt[j])
	{
		int v = go[j];
		if (v == pa) continue;
		int gv = calc(v, dep[u]);
		plu(c[rt[v]].add, gv);
		rt[u] = merge(rt[u], rt[v], 0, 0, 0, L, 1, 1);
	}
	if (a[u] > 0) cover(rt[u], 0, L, 0, a[u] - 1);
	if (dep[u] <= L) cover(rt[u], 0, L, dep[u], L);
}

int main()
{
	freopen("destiny.in", "r", stdin);
	freopen("destiny.out", "w", stdout);
	read(n);
	int i, x, y;
	for (i = 1; i < n; i++) read(x), read(y), link(x, y);
	dfs1(1, 0);
	read(m);
	while (m--)
	{
		read(x); read(y);
		a[y] = max(a[y], dep[x]);
	}
	dfs2(1, 0);
	cout << calc(1, 0) << endl;
	fclose(stdin);
	fclose(stdout);
	return 0;
}
posted @ 2021-11-18 23:12  花淇淋  阅读(53)  评论(0编辑  收藏  举报