填树

题意:

给出 \(n\) 个点的树,初始时每个点的权值为 \(0\)。每次可以选择树上一条路径,可以将路径上的每个点 \(i\) 赋上 \(l_i \sim r_i\) 的权值,并且满足这条路径上最大最小值之差 \(\leq K\)

求这样做后,本质不同的树的个数 以及 本质不同的树的权值之和。

路径互不相同

  • 容易发现选出的不同的路径赋值后所构成的树一定不相同。因此,每一条路径就可以看成一个序列,路径条数是 \(n^2\) 的, 这样就转化成序列问题:每次给定 \(m(\leq n)\) 个点,每个点有限制,满足最大最小值之差 \(\leq K\)

容斥

  • 考虑枚举最小值 \(i\), 对于每个点 \(j\), 容易发现最小的能取的数是 \(\max(i, l_j )\), 最大能取的数是 \(\min(i + K, r_j)\)

  • 当最大的数 \(<\) 最小的数,就取不了数。

那么想到每个点选择相互独立,就有下面的式子:

\[\sum_{i = 0}^{V} \prod_{j = 1}^{m}\max(\min(i + K, r_j) - \max(i, l_j ) + 1, 0) \]

\(V\) 是值域。

实际上,这里算出来的方案数是有可能是选不到 \(i\) 的,可以理解为求了最小值 \(\geq i\) 的方案数,要强制令其选到至少一个 \(i\),也就是要求 最小值 \(= i\) 的方案数。

不妨设 \(f(l, r)\) 表示对每个点再限制一个范围 \([l, r]\) 的本质不同的序列个数,上面实际上求的是 \(\sum_{i = 0}^{ma}f(i, i + K)\)

那我们实际要求的是 \(\sum_{i = 0}^{ma} f(i, i + K) - f(i + 1, i + K)\)

  • 直观理解为 \((=i) \gets (\geq i) - (\geq i + 1)\)

这样的时间复杂度是 \(O(n^2 \times n \times V)\)

拉格朗日插值优化

目前首要的任务是优化掉 \(V\)

  • 考虑如何计算 \(f(l, r)\),

\[f(l, r) = \prod_{j = 1}^{m}\max(\min(r, r_j) - \max(l, l_j ) + 1, 0) \]

这个 \(\min \max\) 不好搞,考虑分类讨论,下面会讨论 \(f(i, i + K)\) 中每一个点的取值。

  1. 考虑什么时候会取到 \(0\):

    发现只要 \([i, i + K]\)\([l_j, r_j]\) 没有交集,自然会取 \(0\),

    \(i + K < l_j\)\(i > r_j\)

  2. \(i + K \leq r_j, i + K > r_j\) 和 $i \leq l_j, i >l_j $考虑。

    经过大力讨论,每个点会得到一个分段函数,

    \[g(i, i + K) = \left\{ \begin{aligned} &i + K - l_j + 1 &(i < l_j)\\ &\min(K + 1, r_j - l_j + 1) &(l_j \leq i < r_j - K) \\ &r_j - i + 1 &(r_j - K \leq i)\\ \end{aligned} \right.\]

总的来说, 令 \(l' = l_j, r' = r_j - K\)

\[g(i, i + K) = \left\{ \begin{aligned} &0 &(i \leq l_j - K)\\ &i + K - l_j + 1 &(l_j - K \leq i < l')\\ &\min(K + 1, r_j - l_j + 1) &(l' \leq i < r') \\ &r_j - i + 1 &(r' \leq i < r_j)\\ &0 &(i > r_j)\\ \end{aligned} \right.\]

这里满足 \(l_j \leq r_j - K\), 如果 \(l_j > r_j - K\), 那么 \(l', r'\) 互换。

同理要减去的那部分的式子也差不多。

这里给出分段函数的式子:

\(l' = l_j - 1, r' = r_j - K\)

\[g(i + 1, i + K) = \left\{ \begin{aligned} &0 &(i \leq l_j - K)\\ &i + K - l_j + 1 &(l_j - K \leq i < l')\\ &\min(K, r_j - l_j + 1) &(l' \leq i < r') \\ &r_j - i &(r' \leq i < r_j)\\ &0 &(i > r_j)\\ \end{aligned} \right.\]

  • 不妨重新令 \(f(i) = g(i, i + K)\), \(h(i) = g(i + 1, i + K)\)

\[f(i) = \prod_{j = 1}^{m} g_i(i, i + K) \]

\[h(i) = \prod_{j = 1}^{m} g_i(i + 1, i + K) \]

  • 方案数为:

\[\sum_{i = 0}^{V} f(i) - h(i) \]

  • 每个 \(g\) 是个一次函数,这个 \(f\) 就是 \(m\) 次的多项式。

在令 \(f\) 做前缀和,它就是 \(m + 1\) 次的多项式。

只要用连续 \(m + 2\) 个点就能用 lagrange插值 计算出一段的和,

由于每个 \(g\)\(5\)段,对于所有点的所有分段点重排后,在每对相邻点计算区间的和即可。

时间复杂度 \(O(n^4)\)

  • 对于第二问树的带权和,

考虑一个点的贡献,计算剩下点能构成的合法路径条数,乘上每个点能选出的值的和。

对每个点求和即是答案,计算方式也和上面差不多。

分段函数改成等差数列求和即可。

60分记录在此

树形dp

现在瓶颈就在枚举出所有路径,

必须得将路径共同考虑。

先考虑用 \([l, r]\) 限制整个树的取值,要求不同的树的个数 \(\text{ans1}\), 权值和 \(\text{ans2}\)

  • \(f_u\) 表示所有子树中的叶子到 \(u\) 的路径不同赋值方案的和,
    \(h_u\) 表示所有子树中的叶子到 \(u\) 的路径不同赋值方案的点权和。

  • 对于每个点初始状态,

    \(f_u = b_1 = \min(r, r_u) - \max(l, l_u) + 1\), 表示根到根的路径的不同方案。

    \(h_u = b_2 = \frac{(\max(l, l_u) + \min(r, r_u)) \times f_u}{2}\), 表示点权和。

    分别加入 \(\text{ans1}\), \(\text{ans2}\)

  • 对最浅点是该点的路径计算,考虑像背包合并子树一样的过程,枚举子树 \(v\)

    • 每次都能和之前的路径连接上:对 \(\text{ans1}\) 贡献 \(f_u \times f_v\)

    • 之前的路径还能接上新的路径,两条路径贡献分别增多:对于 \(\text{ans2}\) 贡献 \(h_u \times f_v + f_v \times f_u\)

    • 对于 \(f_u\)\(f_v\)\(u\) 点形成新路径: \(f_u += f_v \times b_1\)

    • 对于 \(h_u\)\(h_v\) 多出的贡献以及 \(u\)点多出的贡献: \(h_u += h_v \times b_1 + f_v \times b_2\)

  • 这时还是枚举最小值 \(i\), 用 \(dp(l, r)\) 表示限制 \([l, r]\)下,最终答案\(\text{ans1}\)

    容斥: \(dp(i, i + K) - dp(i + 1, i + K)\)

发现 \(f_u = g_u(i + 1, i + K)\),这个 \(g_u(i + 1, i + K)\) 是前前面推的分段函数。

仍然考虑拉格朗日插值,按照转折点分段。

这时时间复杂度 \(O(n^3)\)

int n, m, K; 
int L[MAXN], R[MAXN], l[MAXN], r[MAXN], d[MAXN]; 

vector<int> e[MAXN]; 

int inv[MAXN]; 
int lagrange(vector<pair<int, int>> &v, int x) {
	int sum = 0, n = v.size();
	for (auto i : v) 
		if (i.first == x) 
			return i.second; 
	int prod1 = 1, p = n & 1 ? 1 : mod - 1; 
	for (int i = 0; i < n; i ++)
		prod1 = mul(prod1, dec(x, v[i].first)); 
	for (int i = 0; i < n; i ++) {
		sum = add(sum, mul(v[i].second, mul(mul(prod1, qpow(dec(x, v[i].first), mod - 2)), mul(mul(inv[i], inv[n - i - 1]), p))));	
		p = mod - p; 
	}
	return sum; 
}

int f[MAXN], g[MAXN]; 
int ans1, ans2, val1[MAXN], val2[MAXN];
void dfs(int u, int fa) {
	ans1 = add(ans1, (f[u] = val1[u]));
	ans2 = add(ans2, (g[u] = val2[u])); 
	for (int v : e[u]) {
		if (v == fa) continue; 
		dfs(v, u); 
		ans1 = add(ans1, mul(f[u], f[v])); 
		ans2 = add(ans2, add(mul(g[u], f[v]), mul(g[v], f[u]))); 
		f[u] = add(f[u], mul(val1[u], f[v])); 
		g[u] = add(g[u], add(mul(val1[u], g[v]), mul(val2[u], f[v]))); 
	}
}
inline int G(int l, int r) {
	if (l > r) return 0; 
	return 1ll * (l + r) * (r - l + 1) / 2 % mod; 
}
pair<int, int> calc(int l, int r) {
	ans1 = ans2 = 0;
	for (int i = 1; i <= n; i ++) {
		if (R[i] < l || r < L[i]) val1[i] = val2[i] = 0;
		else {
			val1[i] = min(r, R[i]) - max(l, L[i]) + 1; 
			val2[i] = G(max(l, L[i]), min(r, R[i])); 
		}
	}
	dfs(1, 0);
	return make_pair(ans1, ans2); 
}
pair<int, int> calc(int x) {
	auto a = calc(x, x + K), b = calc(x + 1, x + K);
	return make_pair(dec(a.first, b.first), dec(a.second, b.second));  
} 
pair<int, int> solve(int l, int r) {
	vector<pair<int, int>> v1, v2;
	int sum1 = 0, sum2 = 0; 
	for (int i = l; i <= min(r, l + n + 2); i ++) {
		auto p = calc(i); 
		sum1 = add(sum1, p.first);
		sum2 = add(sum2, p.second); 
		v1.emplace_back(i, sum1); 
		v2.emplace_back(i, sum2); 
	}
	return make_pair(lagrange(v1, r), lagrange(v2, r)); 
}

int main() { 
	cin >> n >> K;
	int p = 1; 
	for (int i = 1; i <= n + 3; i ++) p = mul(p, i);
	inv[n + 3] = qpow(p, mod - 2); 
	for (int i = n + 2; i >= 0; i --)
		inv[i] = mul(inv[i + 1], i + 1); 
		
	int mi = INF, ma = 0; 
	for (int i = 1; i <= n; i ++) {
		cin >> L[i] >> R[i];
		mi = min(mi, L[i]); 
		ma = max(ma, R[i]);
		d[++ m] = L[i]; 
		d[++ m] = max(L[i] - 1, 0);
		d[++ m] = max(L[i] - K, 0); 
		d[++ m] = max(R[i] - K, 0);  
		d[++ m] = R[i]; 
	}
	sort(d + 1, d + m + 1); 
	m = unique(d + 1, d + m + 1) - d - 1; 
		 
	for (int i = 1; i < n; i ++) {
		int u, v;
		cin >> u >> v;
		e[u].emplace_back(v); 
		e[v].emplace_back(u); 
	}
	
	int ans1 = 0, ans2 = 0; 
	for (int i = 2; i <= m; i ++) {
		auto ans = solve(d[i - 1], d[i] - 1); 
		ans1 = add(ans1, ans.first);
		ans2 = add(ans2, ans.second); 
	} 
	auto ans = solve(d[m], d[m]); 
	ans1 = add(ans1, ans.first); 
	ans2 = add(ans2, ans.second);  
	cout << ans1 << endl << ans2 << endl; 
	return 0;
}
posted @ 2022-04-20 21:33  qjbqjb  阅读(68)  评论(2编辑  收藏  举报