CF490F Treeland Tour 题解
线段树合并维护子树信息
Statement
CF490F Treeland Tour - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
给出一棵带点权树,求树上最长上升子序列的长度
\(n\le 6\times 10^3\)
Solution
显然我们有 \(O(n^2\log n)\) 暴力对于每一个点做一次 LIS 的方法可以艹过去
它的核心代码大概长这样:
void dfs(int k,int fa)
{
int pos=lower_bound(f+1,f+n+1,val[k])-f,tmp=f[pos];
f[pos]=val[k],ans=max(ans,pos);
for (int i=head[k];i;i=edge[i].nxt)
{
int v=edge[i].to;
if (v!=fa) dfs(v,k);
}
f[pos]=tmp;
}
for(int i=1;i<=n;++i)dfs(i,0);
考虑 \(n\le 10^5\) ,上面暴力求 LIS 的思路存在瓶颈,我们考虑换一种方式算
另外一个思路是,对于每个点 \(u\) ,维护以子树内的点为结尾的 \(LIS\) 和 \(LDS\) 信息,考虑点 \(u\) 贡献来源
- 经过该点的最长上升子序列
- 不经过该点的最长上升子序列(即两颗不同子树间的信息合并)
我们考虑线段树合并,线段树叶子保存以 \(x\) 为结尾的 LIS 和 LDS 的长度
对于第一种贡献,我们可以直接从儿子里面扒出 LIS 和 LDS 的信息,然后再在序列末尾追加一个 \(u\) 节点。也就是查一下子树里面最长的末尾值 \(\le val[u]\) 的序列然后+1 之类。
对于第二种情况,我们直接在线段树合并的时候从合并的两棵线段树里每层的左右子树里面选左边的最长上升子序列,右边的最长下降子序列合并即可,记得交换一下两棵树的顺序再更新一次答案。
这也为我们合并子树信息提供了一种新思路,即在线段树合并的过程中计算路径贡献
Code
#include<bits/stdc++.h>
#define mid ((l+r)>>1)
#define min(a,b) ((a)<(b)?(a):(b))
#define max(a,b) ((a)>(b)?(a):(b))
using namespace std;
const int N = 1e5+5;
char buf[1<<23],*p1=buf,*p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
int read(){
int s=0,w=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
while(isdigit(ch))s=s*10+(ch^48),ch=getchar();
return s*w;
}
struct Node{int x,y;}son[N];
int lc[N*80],rc[N*80],mx[N*80][2];
int val[N],b[N],rot[N];
vector<int>Edge[N];
int n,ans,t,siz;
void pushup(int p){
mx[p][0]=max(mx[lc[p]][0],mx[rc[p]][0]);
mx[p][1]=max(mx[lc[p]][1],mx[rc[p]][1]);
}
void alter(int l,int r,int& p,int id,int v,int op){
if(!p)p=++siz;
if(l==r)return mx[p][op]=max(v,mx[p][op]),void();
id<=mid?alter(l,mid,lc[p],id,v,op):alter(mid+1,r,rc[p],id,v,op);
pushup(p);
}
int query(int l,int r,int p,int L,int R,int op){
if(!p)return 0;
if(L<=l&&r<=R)return mx[p][op];
if(R<=mid)return query(l,mid,lc[p],L,R,op);
if(L>mid)return query(mid+1,r,rc[p],L,R,op);
return max(query(l,mid,lc[p],L,R,op),query(mid+1,r,rc[p],L,R,op));
}
int merge(int l,int r,int p,int q){
ans=max(ans,max(mx[lc[p]][0]+mx[rc[q]][1],mx[rc[p]][1]+mx[lc[q]][0]));
if(!p||!q)return p+q;
if(l==r){
mx[p][0]=max(mx[p][0],mx[q][0]);
mx[p][1]=max(mx[p][1],mx[q][1]);
return p;
}
lc[p]=merge(l,mid,lc[p],lc[q]);
rc[p]=merge(mid+1,r,rc[p],rc[q]);
return pushup(p),p;
}
bool cmpx(Node x,Node y){return x.x<y.x;}
bool cmpy(Node x,Node y){return x.y<y.y;}
void dfs(int u,int fath){
rot[u]=++siz;
int mx0=0,mx1=0,cnt=0;
for(auto v:Edge[u])if(v^fath)dfs(v,u);
for(auto v:Edge[u])if(v^fath){
++cnt;
if(val[u]==1)son[cnt].x=1;
else son[cnt].x=query(1,t,rot[v],1,val[u]-1,0)+1;
if(val[u]==t)son[cnt].y=1;
else son[cnt].y=query(1,t,rot[v],val[u]+1,t,1)+1;
mx0=max(mx0,son[cnt].x),mx1=max(mx1,son[cnt].y);
rot[u]=merge(1,t,rot[u],rot[v]);
}
if(!cnt)ans=max(ans,1),mx0=mx1=1;
if(cnt==1)ans=max(ans,max(son[1].x,son[1].y));
if(cnt>1){
sort(son+1,son+1+cnt,cmpx);
sort(son+1,son+0+cnt,cmpy);
ans=max(ans,son[cnt].x+son[cnt-1].y-1);
sort(son+1,son+1+cnt,cmpy);
sort(son+1,son+0+cnt,cmpx);
ans=max(ans,son[cnt].y+son[cnt-1].x-1);
}
alter(1,t,rot[u],val[u],mx0,0);
alter(1,t,rot[u],val[u],mx1,1);
}
signed main(){
n=read();
for(int i=1;i<=n;++i)val[i]=b[i]=read();
sort(b+1,b+1+n),t=unique(b+1,b+1+n)-b-1;
for(int i=1;i<=n;++i)val[i]=lower_bound(b+1,b+1+t,val[i])-b;
for(int i=1,u,v;i<n;++i)u=read(),v=read(),Edge[u].push_back(v),Edge[v].push_back(u);
dfs(1,0),printf("%d\n",ans);
return 0;
}