LOJ#2537. 「PKUWC2018」Minimax 线段树合并

$O(n^2)$ 的式子是好列的,然后我们发现这是一个关于前后缀的转移.   

用线段树合并优化这一过程.   

具体地,分别维护 $x,y$ 的后缀和.   

这里要注意:由于这道题中两个不同子树肯定没有交集,所以在线段树合并的时候肯定会合并到一个点,使得两个树中一个为空.  

然后由于另一个是空的,就没有合并的必要了,这样整个区间乘的就是一个相同的数了.  

这样就只需要维护一个乘法标记就行了.

code: 

#include <bits/stdc++.h>   
#define ll long long 
#define mod 998244353 
#define N 300008  
#define lson s[x].ls 
#define rson s[x].rs      
#define setIO(s) freopen(s".in","r",stdin)  
using namespace std;        
int fa[N],cn,n,tot,ans;  
int ch[N][2],val[N],perc[N],A[N],v[N];   
int qpow(int x,int y) 
{
	int tmp=1;  
	for(;y;y>>=1,x=(ll)x*x%mod)   
		if(y&1) tmp=(ll)tmp*x%mod;  
	return tmp;  
}
inline int INV(int x) { return qpow(x,mod-2); }                  
struct data 
{
	int ls,rs; 
	ll sum,tag;   
	// data(){ ls=rs=sum=0,tag=1; }  
}s[N*50];   
int rt[N];  
inline int newnode() { return ++tot; }   
inline void pushup(int x) { s[x].sum=(ll)(s[lson].sum+s[rson].sum)%mod; }    
void update(int &x,int l,int r,int p,int v) 
{
	if(!x) x=newnode(),s[x].tag=1;     
	if(l==r) { s[x].sum=s[x].tag=v; return; }   
	int mid=(l+r)>>1;  
	if(p<=mid)  update(lson,l,mid,p,v);  
	else update(rson,mid+1,r,p,v);  
	pushup(x);            
}   
inline void mark(int x,ll v) 
{
	s[x].tag=(ll)s[x].tag*v%mod;  
	s[x].sum=(ll)s[x].sum*v%mod;   
}       
inline void pushdown(int x) 
{
	if(s[x].tag!=1) 
	{
		if(lson) mark(lson,s[x].tag);            
		if(rson) mark(rson,s[x].tag);  
		s[x].tag=1;  
	}
}
// s1-> 小的 
// s2-> 多的   
int merge(int x,int y,ll det,ll x1,ll x2,ll y1,ll y2) 
{   
	if(!x&&!y) return 0; 
	if(!x) 
	{         
		ll up=(ll)((ll)(1-det+mod)%mod*x2%mod+(ll)det*x1%mod)%mod;    
		mark(y,up);   
		return y;  
	} 
	if(!y) 
	{
		ll up=(ll)((ll)(1-det+mod)%mod*y2%mod+(ll)det*y1%mod)%mod;    
		mark(x,up);  
		return x;  
	} 
	int now=newnode();   
	pushdown(x),pushdown(y);  
	int xr=(ll)(x2+s[s[x].rs].sum)%mod;      
	int yr=(ll)(y2+s[s[y].rs].sum)%mod;   
	int xl=(ll)(x1+s[s[x].ls].sum)%mod;  
	int yl=(ll)(y1+s[s[y].ls].sum)%mod; 
	s[now].tag=1;               
	s[now].ls=merge(s[x].ls,s[y].ls,det,x1,xr,y1,yr);    
	s[now].rs=merge(s[x].rs,s[y].rs,det,xl,x2,yl,y2);   
	pushup(now);  
	return now;  
}
void dfs(int x) 
{
	int l=ch[x][0],r=ch[x][1];  
	if(!l) update(rt[x],1,cn,v[x],1);  
	else if(!r) dfs(l),rt[x]=rt[l];                                                     
	else dfs(l),dfs(r),rt[x]=merge(rt[l],rt[r],1ll*perc[x],0,0,0,0);    
}          
void output(int x,int l,int r) 
{  
	if(!x) return;   
	if(l==r) 
	{
		(ans+=(ll)l*A[l]%mod*s[x].sum%mod*s[x].sum%mod)%=mod;   
		return; 
	}  
	int mid=(l+r)>>1;     
	pushdown(x);  
	output(s[x].ls,l,mid);   
	output(s[x].rs,mid+1,r);  
}
int main() 
{ 
	// setIO("input");               
	scanf("%d",&n);  
	for(int i=1;i<=n;++i) 
	{   
		scanf("%d",&fa[i]);   
		if(ch[fa[i]][0]) ch[fa[i]][1]=i;  
		else ch[fa[i]][0]=i;  
	}   
	for(int i=1;i<=n;++i) 
	{    
		int a;
		scanf("%d",&a);  
		if(!ch[i][0]) val[i]=a,A[++cn]=val[i];  
		else perc[i]=(ll)a*INV(10000)%mod;     
	}    
	sort(A+1,A+1+cn);               
	for(int i=1;i<=n;++i)  if(!ch[i][0]) v[i]=lower_bound(A+1,A+1+cn,val[i])-A;   
	dfs(1),output(rt[1],1,cn),printf("%d\n",ans); 
	return 0; 
}	

  

posted @ 2020-05-29 13:59  EM-LGH  阅读(213)  评论(0编辑  收藏  举报