[PKUWC2018]Minimax

题意:

\(n(\leq3\times 10^5)\) 个结点的二叉树,叶子结点有互不相同的权值 \(v\), 非叶子节点有概率 \(p\) 取两个子节点最大值,否则取最小值。

\(\sum_{i = 1}^{m} i \cdot v_i \cdot d_i^2\), 即 \(m\) 种可能的结果中,第 \(i\) 小的权值 \(v_i\) * \(i\) * 根节点是 \(v_i\) 的概率的平方。

树形dp

由于权值互不相同,所有的可能个数就是叶子节点的个数,只要求出根结点是出现某个权值的概率, 问题就可以解决了。

因为需要比较大小,得确定当前节点的权值。

方程

可以设出方程 \(f_{u, x}\), 表示当前节点 \(u\) 的权值是 第 \(x\) 大的权值 \(val\) 的概率。

初始状态

对于每个叶子节点 \(u\),这个点是 \(v_{u}\) 的概率肯定是 \(1\)

因此: \(f_{u, x} = 1\)

转移

分类讨论当前非叶子节点 \(u\) 的权值 \(val\) 从哪里来,令 \(val\) 是第 \(x\) 大的权值。

  1. 如果只有一个儿子,那就一定是从这个儿子来:

  2. 考虑从左节点来,有最大值和最小值两种途径。

    1. 最大值:

      \[f_{u, x} = f_{l, x} \times \sum_{i = 1}^{x - 1} f_{r, i} \times p \]

    2. 最小值:

      \[f_{u, x} = f_{l, x} \times \sum_{i = x + 1}^{m} f_{r, i} \times (1 - p) \]

对于右节点同理。

就有:

\[f_{u, x} = f_{l, x} \sum_{i = 1}^{x - 1} f_{r, i} \times p + f_{l, x} \sum_{i = x + 1}^{m} f_{r, i} \times (1 - p) + f_{r, x} \sum_{i = 1}^{x - 1} f_{l, i} \times p + f_{r, x} \sum_{i = x + 1}^{m} f_{l, i} \times (1 - p) \]

化简:

\[f_{u, x} = f_{l, x}(\sum_{i = 1}^{x - 1} f_{r, i} \times p + \sum_{i = x + 1}^{m} f_{r, i} \times (1 - p)) + f_{r, x} (\sum_{i = 1}^{x - 1} f_{l, i} \times p + \sum_{i = x + 1}^{m} f_{l, i} \times (1 - p)) \]

分析

时间复杂度 \(O(n \times m)\), 空间复杂度 \(O(n \times m)\)

对于空间,只有根节点才有用,没必要开那么大,可以用到类似滚动数组的玩意,但要在树上?

对于时间,用前后缀和优化的转移是 \(O(m)\)的,还是不够好?

线段树合并

对于上述问题

线段树合并的空间是 \(O(m \log m)\) 的,因为加入新点的空间花费是 \(O(\log m)\), 最多插入 \(O(m)\) 次。

转移的形式像对应点相加,在乘一个值 \(P\), 简单表示为 \(f_{i} = f1_{i} \times P_1 + f2_i \times P_2\)

而这个 \(P1, P2\) 可以在合并的时候顺路求出, 只用维护区间和就能做到。

如果遇到某个儿子的结点为空,其中的 \(f_{l, x}\)\(f_{r, x}\), 那么打一个乘法的懒标记即可。

线段树合并时间复杂度均摊是 \(O(n \log m)\) 的。

代码

#include<bits/stdc++.h>

using namespace std;

using ll = long long;
const int MAXN = 300010;
const int INF = 0x7fffffff;
const int mod = 998244353;

template <typename T>
void Read(T &x) {
	x = 0; T f = 1; char a = getchar();
	for(; a < '0' || '9' < a; a = getchar()) if (a == '-') f = -f;
	for(; '0' <= a && a <= '9'; a = getchar()) x = (x * 10) + (a ^ 48);
	x *= f;
}

int add(int a, int b) {
	int c = a + b;
	if (c >= mod) c -= mod;
	if (c < 0) c += mod;
	return c; 
} 
int mul(int a, int b) {
	return 1ll * a * b % mod; 
}
int qpow(int a, int b) {
	int sum(1);
	while(b) {
		if (b & 1) sum = mul(sum, a);
		a = mul(a, a);
		b >>= 1;
	}
	return sum; 
}

int n; 
vector<int> e[MAXN];
int val[MAXN]; 

int cnt;
int L[MAXN << 5], R[MAXN << 5], Mul[MAXN << 5], sum[MAXN << 5]; 

int newnode() {
	int rt = ++ cnt;
	L[rt] = R[rt] = sum[rt] = 0;
	Mul[rt] = 1;
	return rt; 
}

void pushup(int rt) {
	sum[rt] = add(sum[L[rt]], sum[R[rt]]);  
}
void _mul(int rt, int val) {
	if (!rt) return ; 
	sum[rt] = mul(sum[rt], val); 
	Mul[rt] = mul(Mul[rt], val); 
}
void pushdown(int rt) {
	if (!rt) return ;
	if (L[rt]) _mul(L[rt], Mul[rt]);
	if (R[rt]) _mul(R[rt], Mul[rt]);
	Mul[rt] = 1; 
}
int update(int _L, int _C, int l, int r, int rt) {
	if (!rt) rt = newnode(); 
	if (l == r) {
		sum[rt] = _C;
		return rt; 
	}
	pushdown(rt); 
	int m = (l + r) >> 1; 
	if (_L <= m) L[rt] = update(_L, _C, l, m, L[rt]);
	else R[rt]= update(_L, _C, m + 1, r, R[rt]);
	pushup(rt); 
	return rt; 
}
int merge(int x, int y, int sum1, int sum2, int p, int l, int r) {
	if (!x && !y) return 0; 
	if (!x) {
		_mul(y, sum2);
		return y;
	} 
	if (!y) {
		_mul(x, sum1);
		return x; 
	}
	pushdown(x), pushdown(y); 
	int m = (l + r) >> 1; 
	int ls1 = sum[L[x]], ls2 = sum[L[y]], rs1 = sum[R[x]], rs2 = sum[R[y]]; 
	L[x] = merge(L[x], L[y], add(sum1, mul(add(1, -p), rs2)), add(sum2, mul(add(1, -p), rs1)), p, l, m);
	R[x] = merge(R[x], R[y], add(sum1, mul(p, ls2)), add(sum2, mul(p, ls1)), p, m + 1, r);
	pushup(x); 
	return x;   
}

int len;
int b[MAXN];

int root[MAXN]; 
void dfs(int u) {
	if (!e[u].size()) {
		root[u] = update(lower_bound(b + 1, b + len + 1, val[u]) - b, 1, 1, len, root[u]); 
	} else if(e[u].size() == 1) {
		int son = e[u][0];
		dfs(son);
		root[u] = root[son];	
	} else {
		int l = e[u][0], r = e[u][1]; 
		dfs(l), dfs(r); 
		root[u] = merge(root[l], root[r], 0, 0, val[u], 1, len);
	}
}

int query(int l, int r, int rt) {
	if (!rt) return 0;
	if (l == r) {
		return mul(mul(l, b[l]), qpow(sum[rt], 2)); 
	}
	int m = (l + r) >> 1;
	pushdown(rt); 
	return add(query(l, m, L[rt]), query(m + 1, r, R[rt])); 
}
 
int main() {
	Read(n);
	for (int i = 1; i <= n; i ++) {
		int fa;
		Read(fa);
		e[fa].emplace_back(i); 
	}
	
	for (int i = 1; i <= n; i ++) {
		Read(val[i]); 
		if (e[i].size())
			val[i] = mul(val[i], qpow(10000, mod - 2)); 
		else 
			b[++ len] = val[i]; 
	}
	sort(b + 1, b + len + 1);
	len = unique(b + 1, b + len + 1) - b - 1;
	dfs(1); 
	cout << query(1, len, root[1]); 
	return 0;
} 
posted @ 2022-02-28 14:35  qjbqjb  阅读(49)  评论(0编辑  收藏  举报