LOJ2537. 「PKUWC2018」Minimax [DP,线段树合并]

传送门

思路

首先有一个\(O(n^2)\)的简单DP:设\(dp_{x,w}\)\(x\)的权值为\(w\)的概率。

假设\(w\)来自\(v1\)的子树,那么有

\[dp_{x,w}=dp_{v1,w}\times (p\times \sum_{w'>w}dp_{v2,w'}+(1-p)\sum_{w'<w}dp_{v2,w'}) \]

其中\(p\)表示\(x\)选较小权值的概率。

由于每个点的状态数只有子树中的叶子个数,可以考虑线段树合并来优化这一DP过程。

merge(k1,k2,l,r)函数中传进去几个参数:\(v1\)\(v2\)分别在\([1,l-1]\)的权值和&在\([r+1,n]\)的权值和。当某一个节点为0时就整个子树打乘法tag。

代码

#include<bits/stdc++.h>
clock_t t=clock();
namespace my_std{
    using namespace std;
    #define pii pair<int,int>
    #define fir first
    #define sec second
    #define MP make_pair
    #define rep(i,x,y) for (int i=(x);i<=(y);i++)
    #define drep(i,x,y) for (int i=(x);i>=(y);i--)
    #define go(x) for (int i=head[x];i;i=edge[i].nxt)
    #define templ template<typename T>
    #define sz 303300
    #define mod 998244353ll
    typedef long long ll;
    typedef double db;
    mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
    templ inline T rnd(T l,T r) {return uniform_int_distribution<T>(l,r)(rng);}
    templ inline bool chkmax(T &x,T y){return x<y?x=y,1:0;}
    templ inline bool chkmin(T &x,T y){return x>y?x=y,1:0;}
    templ inline void read(T& t)
    {
        t=0;char f=0,ch=getchar();double d=0.1;
        while(ch>'9'||ch<'0') f|=(ch=='-'),ch=getchar();
        while(ch<='9'&&ch>='0') t=t*10+ch-48,ch=getchar();
        if(ch=='.'){ch=getchar();while(ch<='9'&&ch>='0') t+=d*(ch^48),d*=0.1,ch=getchar();}
        t=(f?-t:t);
    }
    template<typename T,typename... Args>inline void read(T& t,Args&... args){read(t); read(args...);}
    char __sr[1<<21],__z[20];int __C=-1,__zz=0;
    inline void Ot(){fwrite(__sr,1,__C+1,stdout),__C=-1;}
    inline void print(register int x)
    {
        if(__C>1<<20)Ot();if(x<0)__sr[++__C]='-',x=-x;
        while(__z[++__zz]=x%10+48,x/=10);
        while(__sr[++__C]=__z[__zz],--__zz);__sr[++__C]='\n';
    }
    void file()
    {
        #ifdef NTFOrz
        freopen("a.in","r",stdin);
        #endif
    }
    inline void chktime()
    {
        #ifndef ONLINE_JUDGE
        cout<<(clock()-t)/1000.0<<'\n';
        #endif
    }
    #ifdef mod
    ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x%mod) if (y&1) ret=ret*x%mod;return ret;}
    ll inv(ll x){return ksm(x,mod-2);}
    #else
    ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x) if (y&1) ret=ret*x;return ret;}
    #endif
//	inline ll mul(ll a,ll b){ll d=(ll)(a*(double)b/mod+0.5);ll ret=a*b-d*mod;if (ret<0) ret+=mod;return ret;}
}
using namespace my_std;
inline ll MOD(ll x){return x-((mod-x)>>31&mod);}

int n;
int w[sz],_w[sz];
struct hh{int t,nxt;}edge[sz<<1];
int head[sz],ecnt;
void make_edge(int f,int t){edge[++ecnt]=(hh){t,head[f]};head[f]=ecnt;}
int son[sz];

int root[sz],cc;
#define Tree sz*30
int ls[Tree],rs[Tree];
ll val[Tree],sum[Tree],tag[Tree];
#define lson ls[k],l,mid
#define rson rs[k],mid+1,r
void Mul(int k,ll w){tag[k]=tag[k]*w%mod;sum[k]=sum[k]*w%mod;val[k]=val[k]*w%mod;}
void pushup(int k){sum[k]=MOD(sum[ls[k]]+sum[rs[k]]);}
void pushdown(int k)
{
	if (tag[k]==1) return;
	ls[k]&&(Mul(ls[k],tag[k]),0);
	rs[k]&&(Mul(rs[k],tag[k]),0);
	tag[k]=1;
}
int merge(int k1,int k2,int l,int r,ll pre1,ll suf1,ll pre2,ll suf2,ll p)
{
	if (k1+k2==0) return 0;
	if (!k1||!k2)
	{
		if (!k1) swap(k1,k2),swap(pre1,pre2),swap(suf1,suf2);
		Mul(k1,MOD(p*suf2%mod+(mod+1-p)*pre2%mod));
		return k1;
	}
	pushdown(k1);pushdown(k2);
	int mid=(l+r)>>1,k=++cc;tag[k]=1;
	ll L1=sum[ls[k1]],R1=sum[rs[k1]],L2=sum[ls[k2]],R2=sum[rs[k2]];
	ls[k]=merge(ls[k1],ls[k2],l,mid,pre1,MOD(suf1+R1),pre2,MOD(suf2+R2),p);
	rs[k]=merge(rs[k1],rs[k2],mid+1,r,MOD(pre1+L1),suf1,MOD(pre2+L2),suf2,p);
	pushup(k);
	return k;
}
void insert(int &k,int l,int r,int x)
{
	k=++cc;tag[k]=1;
	if (l==r) return (void)(sum[k]=val[k]=1);
	int mid=(l+r)>>1;
	if (x<=mid) insert(ls[k],l,mid,x);
	else insert(rs[k],mid+1,r,x);
	pushup(k);
}
ll ans;
void getans(int k,int l,int r)
{
	if (!k) return;
	if (l==r) return (void)((ans+=1ll*l*_w[l]%mod*val[k]%mod*val[k]%mod)%=mod);
	int mid=(l+r)>>1;
	pushdown(k);
	getans(ls[k],l,mid);getans(rs[k],mid+1,r);
}

void dfs(int x)
{
	if (!son[x]) return insert(root[x],1,n,w[x]);
	int v1=0,v2=0;
	go(x) dfs(edge[i].t),(v1?v2:v1)=edge[i].t;
	if (!v2) root[x]=root[v1];
	else root[x]=merge(root[v1],root[v2],1,n,0,0,0,0,w[x]);
}

int main()
{
    file();
	read(n);
	int x;read(x);
	rep(i,2,n) read(x),make_edge(x,i),son[x]++;
	x=0;
	rep(i,1,n) { read(w[i]); if (!son[i]) _w[++x]=w[i]; }
	sort(_w+1,_w+x+1);x=unique(_w+1,_w+x+1)-_w-1;
	rep(i,1,n) 
		if (!son[i]) w[i]=lower_bound(_w+1,_w+x+1,w[i])-_w; 
		else w[i]=(10000-w[i])*inv(10000)%mod;
	dfs(1);
	getans(root[1],1,n);
	cout<<ans;
	return 0;
}
posted @ 2019-05-16 14:46  p_b_p_b  阅读(289)  评论(0编辑  收藏  举报