【UOJ #284】— 快乐游戏鸡(长链剖分+线段树)

传送门

可以发现我们大致是要维护一个每个深度的最大值之类的东西
考虑离线下来对每个点处理所有子树内的询问

f[i][j]f[i][j]表示点ii子树深度为jj以内的最大值是多少
那么f[i][j]f[i][j1]f[i][j]-f[i][j-1]就是深度jj走的次数
那么对于一个询问(s,t)(s,t)
答案就是i=1d=dep[pos[mxval]]dep[s](f[s][i]f[s][i1])i\sum_{i=1}^{d=dep[pos[mxval]]-dep[s]}(f[s][i]-f[s][i-1])*i
=f[s][d]di=1d1f[s][i]=f[s][d]*d-\sum_{i=1}^{d-1}f[s][i]

由于和深度有关考虑长链剖分维护这样一个ff
考虑合并2条链,即每次加入一个vv更新ff
显然vv是要和一段后缀取maxmax
由于ff是单调的所以可以二分出第一个大于vv的位置覆盖就可以了

代码简单好写

#include<bits/stdc++.h>
using namespace std;
const int RLEN=1<<20|1;
inline char gc(){
    static char ibuf[RLEN],*ib,*ob;
    (ob==ib)&&(ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin));
    return (ob==ib)?EOF:*ib++;
}
#define gc getchar
inline int read(){
    char ch=gc();
    int res=0,f=1;
    while(!isdigit(ch))f^=ch=='-',ch=gc();
    while(isdigit(ch))res=(res+(res<<2)<<1)+(ch^48),ch=gc();
    return f?res:-res;
}
#define ll long long
#define re register
#define pii pair<int,int>
#define fi first
#define se second
#define pb push_back
#define cs const
#define bg begin
#define poly vector<int>	
inline void chemx(int &a,int b){a<b?a=b:0;}
inline void chemn(int &a,int b){a>b?a=b:0;}
cs int N=300005;
int n,m,dep[N],top[N],fa[N][20],mx[N][20],mxdep[N],son[N],val[N];
struct Seg{
	vector<int> mx,tag;
	vector<ll> s;int n;
	inline void build(int len){
		n=len-1,mx.resize(len<<2,0),tag.resize(len<<2,0),s.resize(len<<2,0);
	}
	#define lc (u<<1)
	#define rc (u<<1|1)
	#define mid ((l+r)>>1)
	inline void pushup(int u){
		s[u]=s[lc]+s[rc];
		mx[u]=max(mx[lc],mx[rc]);
	}
	inline void pushnow(int u,int len,int k){
		tag[u]=mx[u]=k,s[u]=1ll*k*len;
	}
	inline void pushdown(int u,int l,int r){
		if(!tag[u])return;
		pushnow(lc,mid-l+1,tag[u]);
		pushnow(rc,r-mid,tag[u]);
		tag[u]=0;
	}
	void update(int u,int l,int r,int st,int des,int k){
		if(st<=l&&r<=des)return pushnow(u,r-l+1,k);
		pushdown(u,l,r);
		if(st<=mid)update(lc,l,mid,st,des,k);
		if(mid<des)update(rc,mid+1,r,st,des,k);
		pushup(u);
	}
	inline int div(int u,int l,int r,int st,int des,int k){
		if(l==r){
			if(mx[u]<k)return des+1;
			return l;
		}
		pushdown(u,l,r);int res=0;
		if(mx[u]<k)res=des+1;
		else if(mx[lc]>=k)res=div(lc,l,mid,st,des,k);
		else res=div(rc,mid+1,r,st,des,k);
		pushup(u);return res;
	}
	ll query(int u,int l,int r,int st,int des){
		if(st<=l&&r<=des)return s[u];
		pushdown(u,l,r);
		ll res=0;
		if(st<=mid)res+=query(lc,l,mid,st,des);
		if(mid<des)res+=query(rc,mid+1,r,st,des);
		pushup(u);return res;
	}
	inline int find(int beg,int val){
		return div(1,0,n,beg,n,val);
	}
	inline void insert(int pos,int val){
		int ps=find(pos,val);
		if(ps>pos)update(1,0,n,pos,ps-1,val);
	}
	inline ll query(int l,int r){
		return query(1,0,n,l,r);	
	}
	#undef lc 
	#undef rc
	#undef mid
}tr[N];
struct node{
	int des,val,id;
	node(int _d=0,int _v=0,int _i=0):des(_d),val(_v),id(_i){}
};
vector<int> e[N];
vector<node> q[N];
ll ans[N];
void dfs1(int u){
	mx[u][0]=val[fa[u][0]];
	for(int i=1;i<=19;i++)fa[u][i]=fa[fa[u][i-1]][i-1],mx[u][i]=max(mx[u][i-1],mx[fa[u][i-1]][i-1]);
	for(int &v:e[u]){
		dep[v]=dep[u]+1,dfs1(v);
		if(mxdep[v]>mxdep[son[u]])son[u]=v;
	}
	mxdep[u]=mxdep[son[u]]+1;
}
void dfs2(int u,int tp){
	top[u]=tp;if(u==tp)tr[u].build(mxdep[u]);
	if(son[u])dfs2(son[u],tp);
	for(int &v:e[u])if(v!=son[u])dfs2(v,v);
}
inline int Lca(int u,int v){
	int res=0,d=dep[v]-dep[u]-1;
	for(int i=19;~i;i--)if(d&(1<<i))chemx(res,mx[v][i]),v=fa[v][i];
	return res;
}
inline void merge(int u,int v){
	for(int i=0;i<mxdep[v];i++){
		tr[top[u]].insert(dep[v]-dep[top[u]]+i,tr[v].query(i,i));
	}
}
void dp(int u){
	if(son[u])dp(son[u]);
	for(int &v:e[u]){
		if(v==son[u])continue;
		dp(v),merge(u,v);
	}
	for(node &x:q[u]){
		int des=x.des,val=x.val;ll res=0;
		int pos=tr[top[u]].find(dep[u]-dep[top[u]]+1,val)-(dep[u]-dep[top[u]]);
		if(pos>1)res-=tr[top[u]].query(dep[u]-dep[top[u]]+1,dep[u]-dep[top[u]]+pos-1);
		res+=1ll*pos*val+dep[des]-dep[u];ans[x.id]=res;
	}
	tr[top[u]].insert(dep[u]-dep[top[u]],val[u]);
}
int main(){
	n=read();
	for(int i=1;i<=n;i++)val[i]=read();
	for(int i=2;i<=n;i++)fa[i][0]=read(),e[fa[i][0]].pb(i);
	dfs1(1),dfs2(1,1);
	m=read();
	for(int i=1;i<=m;i++){
		int s=read(),t=read();
		if(s==t){ans[i]=0;continue;}
		int val=Lca(s,t);
		q[s].pb(node(t,val,i));
	}
	dp(1);
	for(int i=1;i<=m;i++)cout<<ans[i]<<'\n';
}
posted @ 2019-09-17 18:16  Stargazer_cykoi  阅读(110)  评论(0编辑  收藏  举报