BZOJ 4919: [Lydsy1706月赛]大根堆
F[x][i]表示x的子树中取的数字<=i的最大值,线段树合并优化DP
写得很难看,并不知道好看的写法
#include<cstdio> #include<algorithm> using namespace std; int cnt,n,Num,ans,last[200005],tag[10000005],ls[10000005],rs[13000005],E[200005],a[200005],Fa[200005],root[200005],ANS[200005],tree[10000005]; struct node{ int to,next; }e[1000005]; void add(int a,int b){ e[++cnt].to=b; e[cnt].next=last[a]; last[a]=cnt; } void push_down(int x){ if (ls[x]) tag[ls[x]]+=tag[x],tree[ls[x]]+=tag[x],tree[ls[x]]=max(tree[ls[x]],tree[x]); if (rs[x]) tag[rs[x]]+=tag[x],tree[rs[x]]+=tag[x],tree[rs[x]]=max(tree[rs[x]],tree[x]); tag[x]=0; } int merge(int x,int y){ if (!x) return y; if (!y) return x; push_down(x); push_down(y); if (!ls[x]) ls[x]=ls[y],tree[ls[x]]+=tree[x],tag[ls[x]]+=tree[x]; else if (!ls[y]) tree[ls[x]]+=tree[y],tag[ls[x]]+=tree[y]; else ls[x]=merge(ls[x],ls[y]); if (!rs[x]) rs[x]=rs[y],tree[rs[x]]+=tree[x],tag[rs[x]]+=tree[x]; else if (!rs[y]) tree[rs[x]]+=tree[y],tag[rs[x]]+=tree[y]; else rs[x]=merge(rs[x],rs[y]); tree[x]+=tree[y]; return x; } int query(int t,int l,int r,int x){ if (!t) return 0; if (l==r) return tree[t]; push_down(t); int mid=(l+r)>>1; if (x<=mid) return max(tree[t],query(ls[t],l,mid,x)); else return max(tree[t],query(rs[t],mid+1,r,x)); } void insert(int &t,int l,int r,int x,int y,int Val){ if (l>y || r<x) return; if (!t) t=++cnt; if (l>=x && r<=y){ tree[t]=max(tree[t],Val); return; } push_down(t); int mid=(l+r)>>1; insert(ls[t],l,mid,x,y,Val); insert(rs[t],mid+1,r,x,y,Val); } void solve(int x){ for (int i=last[x]; i; i=e[i].next){ int V=e[i].to; solve(V); root[x]=merge(root[x],root[V]); } int Key=query(root[x],1,Num,a[x]-1); insert(root[x],1,Num,a[x],Num,Key+1); } int main(){ scanf("%d",&n); for (int i=1; i<=n; i++) { scanf("%d%d",&a[i],&Fa[i]); if (Fa[i]) add(Fa[i],i); } E[++n]=-1e9; for (int i=1; i<=n; i++) E[i]=a[i]; sort(E+1,E+n+1); Num=unique(E+1,E+n+1)-E-1; for (int i=1; i<=n; i++) a[i]=lower_bound(E+1,E+Num+1,a[i])-E; cnt=0; solve(1); for (int i=1; i<=Num; i++) ANS[i]=query(root[1],1,Num,i); int ans=0; for (int i=1; i<=Num; i++) ans=max(ans,ANS[i]); printf("%d\n",ans); return 0; }