luogu P5298 [PKUWC2018]Minimax

luogu P5298 [PKUWC2018]Minimax

题目大意

不可描述
比较清楚就不讲了

题解

首先注意到题目中的这样一句话

保证这类点中每个结点的权值互不相同

显然线段树合并QWQ

考虑如何合并

先考虑只有一边有节点的情况

然后直接把贡献乘在那个子树上(tag)

假设把 以 x 和 y为根的两颗线段树合并

那发现如果 x, y都不为0就直接先往下做然后再 把概率加起来就行了

还是看代码吧

code:

#include<bits/stdc++.h>
#define int long long
#define mod 998244353
#define N 2000005
using namespace std;
struct A{
	int tag, d; 
}a[N << 4];
struct AA{
	int l, r, val, d;
}son[N << 4];
int P, n, ans, m, tot, b[N], ch[N << 4][2], root[N];
void TAG(int rt, int o){
	o %= mod;
	a[rt].tag = a[rt].tag * o % mod;
	a[rt].d = a[rt].d * o % mod;
}
void pushdown(int rt){
	if(a[rt].tag == 1) return;
	if(ch[rt][0]) TAG(ch[rt][0], a[rt].tag);
	if(ch[rt][1]) TAG(ch[rt][1], a[rt].tag);
	a[rt].tag = 1;
}
void insert(int &rt, int l, int r, int x){
	if(! rt) rt = ++ tot;
	a[rt].d = a[rt].tag = 1;
	if(l == r) return;
	int mid = (l + r) >> 1;
	if(x <= mid) insert(ch[rt][0], l, mid, x);
	else insert(ch[rt][1], mid + 1, r, x);
}
int merge(int x, int y, int xl, int xr, int yl, int yr){
	xl %= mod, xr %= mod, yl %= mod, yr %= mod; 
	if(!x && !y) return 0;
	if(x && y){
		pushdown(x), pushdown(y);
		int ll = a[ch[x][0]].d, rr = a[ch[y][0]].d;
		ch[x][0] = merge(ch[x][0], ch[y][0], xl, xr + a[ch[x][1]].d, yl, yr + a[ch[y][1]].d);
		ch[x][1] = merge(ch[x][1], ch[y][1], xl + ll, xr, yl + rr, yr);
		a[x].d = (a[ch[x][0]].d + a[ch[x][1]].d) % mod;
	}else{
		if(!x) swap(x, y), swap(xl, yl), swap(xr, yr);
		TAG(x, yl * P + (mod + 1 - P) * yr);
	}
	return x;
}
void dfs(int u){
	if(son[u].val){
		insert(root[u], 1, m, son[u].val);
	}else if(son[u].r){
		dfs(son[u].l), dfs(son[u].r);
		P = son[u].d;
		root[u] = merge(root[son[u].l], root[son[u].r], 0ll, 0ll, 0ll, 0ll);
	}else{
		dfs(son[u].l);
		root[u] = root[son[u].l];
	}
}
void getans(int rt, int l, int r){
	pushdown(rt);
	if(l == r){
		ans += l * b[l] % mod * a[rt].d % mod * a[rt].d % mod, ans %= mod;
		return;
	}
	int mid = (l + r) >> 1;
	if(ch[rt][0]) getans(ch[rt][0], l, mid);
	if(ch[rt][1]) getans(ch[rt][1], mid + 1, r);
}
signed main(){
	scanf("%lld", &n);
	for(int i = 1; i <= n; i ++){
		int fa;
		scanf("%lld", &fa);
		if(!son[fa].l) son[fa].l = i;
		else son[fa].r = i;
	}
	for(int i = 1; i <= n; i ++){
		int x;
		scanf("%lld", &x);
		if(son[i].l) son[i].d = x * 796898467ll % mod;
		else {
			son[i].val = x;
			b[++ m] = x;
		} 
	}
	sort(b + 1, b + 1 + m);
	
	for(int i = 1; i <= n; i ++) 
		if(son[i].val) son[i].val = lower_bound(b + 1, b + 1 + m, son[i].val) - b;
	dfs(1);
	getans(root[1], 1, m);
	printf("%lld", ans);
	return 0;
}

好题!
终于把这个史前巨坑填了

posted @ 2019-09-28 10:41  lahlah  阅读(33)  评论(0编辑  收藏  举报