[PKUWC2018][LOJ2537]Minimax(线段树合并)

题面

https://loj.ac/problem/2537

题解

前置知识:

首先考虑朴素做法。设\(f[u][i]\)表示点u取到第i大的权值的概率,然后进行树上dp。但是这个dp最多优化到\(O(n^2)\),并且时间和空间都是\(O(n^2)\),无法通过。

考虑使用线段树合并去实现这个dp的过程。对于树上的每一个节点u,开一棵线段树去记录\(f[u]\)。具体实现是,在原树上进行dfs,到点u时:

  1. u是叶子节点:开一条新的链作为u的线段树。
  2. u有一个子节点:u的线段树就是u儿子的线段树。
  3. u有两个子节点:u的线段树就是u两个儿子线段树的并。

合并线段树时,只需要解决一棵线段树与一棵空树怎么\(O(1)\)的合并。

假设原树上dfs到的节点是u,它的左右儿子节点是l、r。如果我们需要将x,y为根的两棵线段树合并(x,y是l,r对应的线段树上,两个相同位置的节点)而y为空,那么:

\[\forall x.l \leq i \leq x.r,f[u][i] = f[l][i](\sum\limits_{j=1}^{i}f[r][j] \times p[u] + \sum\limits_{j=i}^{wn}f[r][j]*(1-p[u])) \]

  • 其中\(x.l,x.r\)表示x节点对应的线段的左右端点,\(wn\)表示不同权值的个数。

并且由于

\[\forall x.l \leq i \leq x.r,f[r][i] =0 \]

所以有

\(\forall x.l \leq i \leq x.r,f[u][i] = f[l][i] \times (\sum\limits_{i=1}^{x.l-1}f[r][j]\times p[u] + \sum\limits_{i=x.r+1}^{wn}f[r][j]\times (1-p[u]))\)

因此,只需要在merge的过程中,时刻维护前缀和\(pre_r=\sum_{i=1}^{x.l-1}f[r][j]\)和后缀和\(suf_r=\sum_{i=x.r+1}^{wn}f[r][j]\)

如果x不为空而y为空,只需要将x打上“整体乘\((pre_r\times p[u] + suf_r \times (1-p[u]))\)”的标记即可。当然也可能出现x为空而\(y\)不为空的情况,相应地维护\(pre_l\)\(suf_l\)就可以啦。

总时间复杂度\(O(n \log n)\)

代码

#include<bits/stdc++.h>

using namespace std;

#define rg register
#define In inline
#define ll long long

const int N = 3e5;
const int TN = 2 * 19 * N;
const int mod = 998244353;
const int inv_10000 = 796898467;

namespace IO{
	In int read(){
		int s = 0,ww = 1;
		char ch = getchar();
		while(ch < '0' || ch > '9'){if(ch == '-')ww = -1;ch = getchar();}
		while('0' <= ch && ch <= '9'){s = 10 * s + ch - '0';ch = getchar();}
		return s * ww;
	}
	In void write(int x){
		if(x < 0)x = -x,putchar('-');
		if(x > 9)write(x / 10);
		putchar('0' + x % 10);
	}
}
using namespace IO;

namespace ModCalc{
	In void Inc(int &x,int y){
		x += y;if(x >= mod)x -= mod;
	}
	In void Dec(int &x,int y){
		x -= y;if(x < 0)x += mod;
	}
	In void Tms(int &x,int y){
		x = 1ll * x * y % mod;
	}
	In int Add(int x,int y){
		Inc(x,y);return x;
	}
	In int Sub(int x,int y){
		Dec(x,y);return x;
	}
	In int Mul(int x,int y){
		Tms(x,y);return x;
	}
}
using namespace ModCalc;

int rt[N+5],p[N+5],q[N+5];
int D[N+5];
int n;

struct SegTree{
	int sum[TN+5],flag[TN+5],lc[TN+5],rc[TN+5];
	int cnt;
	In int newnode(){
		cnt++;
		flag[cnt] = 1;
		return cnt;
	}
	In void pushdown(int u){
		if(flag[u] == 1)return;
		int L = lc[u],R = rc[u];
		if(L){
			Tms(sum[L],flag[u]);
			Tms(flag[L],flag[u]);
		}
		if(R){
			Tms(sum[R],flag[u]);
			Tms(flag[R],flag[u]);
		}
		flag[u] = 1;
	}
	In void pushup(int u){
		sum[u] = Add(sum[lc[u]],sum[rc[u]]);
	}
	int insert(int l,int r,int x,int d){
		int u = newnode();
		if(l == r){
			sum[u] = d;
			return u;
		}
		int m = (l + r) >> 1;
		if(x <= m)lc[u] = insert(l,m,x,d);
		else rc[u] = insert(m + 1,r,x,d);
		pushup(u);
		return u; 
	}
	int merge(int u,int v,int cur,int pu,int su,int pv,int sv){	//u,v对应题解中的x,y;cur对应题解中的u;pu,su,pv,sv对应题解中的pre_l,suf_l,pre_r,suf_r
		if(!u && !v)return 0;
		if(!u){
			int x = Add(Mul(pu,p[cur]),Mul(su,q[cur]));
			Tms(sum[v],x);
			Tms(flag[v],x);
			return v;
		}
		if(!v){
			int x = Add(Mul(pv,p[cur]),Mul(sv,q[cur]));
			Tms(sum[u],x);
			Tms(flag[u],x);
			return u;
		}
		pushdown(u),pushdown(v);
		int dpu = sum[lc[u]],dsu = sum[rc[u]],dpv = sum[lc[v]],dsv = sum[rc[v]];
		lc[u] = merge(lc[u],lc[v],cur,pu,Add(su,dsu),pv,Add(sv,dsv));
		rc[u] = merge(rc[u],rc[v],cur,Add(pu,dpu),su,Add(pv,dpv),sv);
		pushup(u);
		return u;
	}
	void dfs(int u,int l,int r){
		if(l == r)D[l] = sum[u];
		pushdown(u);
		int m = (l + r) >> 1;
		if(lc[u])dfs(lc[u],l,m);
		if(rc[u])dfs(rc[u],m + 1,r);
 	}
}T;

vector<int>link[N+5];
int w[N+5],aw[N+5],wn;

void prepro(){ //离散化
	for(rg int i = 1;i <= n;i++)if(!link[i].size())aw[++wn] = w[i];
	sort(aw + 1,aw + wn + 1);
	for(rg int i = 1;i <= n;i++)if(!link[i].size())
		w[i] = lower_bound(aw + 1,aw + wn + 1,w[i]) - aw;
}

void dfs(int u){
	if(!link[u].size())rt[u] = T.insert(1,wn,w[u],1); 	
	else if(link[u].size() == 1){
		dfs(link[u][0]);
		rt[u] = rt[link[u][0]];
	}
	else{
		int l = link[u][0],r = link[u][1];
		dfs(l),dfs(r);
		rt[u] = T.merge(rt[l],rt[r],u,0,0,0,0);
	}
}

int main(){
	n = read();
	for(rg int i = 1;i <= n;i++)link[read()].push_back(i);
	for(rg int i = 1;i <= n;i++){ 
		int x = read();
		if(!link[i].size())w[i] = x;
		else p[i] = Mul(x,inv_10000),q[i] = Sub(1,p[i]);
	}
	prepro();
	dfs(1);
	T.dfs(rt[1],1,wn);
	int ans = 0;
	for(rg int i = 1;i <= wn;i++)Inc(ans,Mul(Mul(i,aw[i]),Mul(D[i],D[i])));
	write(ans),putchar('\n');
	return 0;
}
posted @ 2020-10-06 12:25  coder66  阅读(206)  评论(0编辑  收藏  举报