蒟蒻的长链剖分学习笔记(例题:HOTEL加强版、重建计划)
长链剖分学习笔记
说到树的链剖,大多数人都会首先想到重链剖分。的确,目前重链剖分在OI中有更加多样化的应用,但它大多时候是替代不了长链剖分的。
重链剖分是把size最大的儿子当成重儿子,顾名思义长链剖分就是把 len (到叶子节点的距离) 最长的儿子当成重儿子。
由于是和深度有关的算法,长链剖分常用于优化一些和深度有关的dp或其他算法。
具体按照蒟蒻的理解来说,就是类似启发式合并的那种感觉,每个点为根的子树都可以看成一条最长的链上支出了一些叉,而我们想把这棵子树捋成只有一条链,毕竟链多清晰明了,还好合并。于是,我们先递归下去重儿子,处理出子树中这条最长的链的信息,然后再枚举轻儿子递归下去,反上来的时候把以这个轻儿子开头的链信息合并到最长链上。
那么问题来了,
什么样的问题能用长链剖分解决呢?
蒟蒻以为,那种与路径长度有关,且能以每个深度为状态,相同深度之间更新没有顺序的问题应该就可以吧。
长链剖分的时空复杂度怎么证明?
可以发现,每次向上合并都是拿一条儿子子树的最长链合并到父亲子树的最长链上,时空代价都是
这里给出两道例题
1、bzoj4543 HOTEL (data_stronger)
注意这里的dp本来是要开二维数组的(f(u)(j)表示u这棵子树,距离u为j的点有多少个),想一想,为什么空间并不会炸呢?
可以发现父亲节点u的 f 数组一开始就是重儿子节点v的 f 数组右移一位得到的,所以我们可以最开始把重儿子 f 数组的指针赋成 f[u]+1.
对于g也是同理。
具体实现方式请看代码及注释吧。
#include<bits/stdc++.h>
using namespace std;
#define N 200005
#define ll long long
int n,gr,h[N],nxt[N],to[N];
inline void tu(int x,int y){to[++gr]=y,nxt[gr]=h[x],h[x]=gr;}
int len[N]{-1},son[N];
void dfs1(int u,int fa){
len[u]=0;
for(int i=h[u];i;i=nxt[i]){
if(to[i]==fa)continue;
dfs1(to[i],u);
if(len[son[u]]<len[to[i]])son[u]=to[i],len[u]=len[son[u]]+1;
}
}
ll mry[N*3],*f[N],*g[N],*tot=mry,ans;//sor
//
//g(i)表示在子树外离当前点距离为i的点可以和子树内多少对点组成答案
inline void get_memory(int u,int siz){
f[u]=tot;tot+=(siz<<1)+2;g[u]=tot;tot+=siz+2;//+2是必需的
}
void dfs2(int u,int fa){
if(son[u]){
f[son[u]]=f[u]+1;g[son[u]]=g[u]-1;
dfs2(son[u],u);
}
f[u][0]=1;
ans+=f[u][0]*g[u][0];//g[u][0]实际上==g[son[u]][1]
for(int i=h[u];i;i=nxt[i]){
int d=to[i];
if(d==fa||d==son[u])continue;
get_memory(d,len[d]);
dfs2(d,u);
for(int j=0;j<=len[d];++j){
ans+=f[d][j]*g[u][j+1];
if(j)ans+=g[d][j]*f[u][j-1];
}
for(int j=0;j<=len[d];++j){
g[u][j+1]+=f[u][j+1]*f[d][j];
if(j)g[u][j-1]+=g[d][j];//折了一下
f[u][j+1]+=f[d][j];
}
}
}
int main(){
scanf("%d",&n);
for(int i=1,x,y;i<n;++i)scanf("%d%d",&x,&y),tu(x,y),tu(y,x);
dfs1(1,0);
get_memory(1,len[1]);
dfs2(1,0);
printf("%lld\n",ans);
return 0;
}
2、bzoj1758 重建计划
一眼看出来的部分略去不说,我们需要解决的就是求max{边数在L和R之间的路径边权和}.
由于转移时要在一段区间取max,这时候我们发现我们转移的时候需要一棵线段树,像上一道题一样动态用指针开内存的方式显然不能接受。
我们选择另一种方式,即有顺序地遍历全树获得dfs序,使得每条长链上的点dfs序是一段连续的区间。
这样,dfs序上在不超过len(长链)的范围下加多少,就是深度向下走多少。便可以轻松转移了。
#include<bits/stdc++.h>
using namespace std;
#define il inline
#define rep(i,a,b) for(register int i=(a);i<=(b);++i)
#define dwn(i,a,b) for(register int i=(a);i>=(b);--i)
#define lc (x<<1)
#define rc (x<<1|1)
typedef double db;
typedef long long ll;
const int N = 200005;
const db eps = 1e-5;
int n,gr,h[N],nxt[N],to[N],w[N],lwr,upp;
inline void tu(int x,int y,int v){to[++gr]=y,nxt[gr]=h[x],h[x]=gr,w[gr]=v;}
int len[N],son[N];//
int dfn[N],tim,fa[N],sonw[N];
struct Tre{
db tr[N<<2];
void modify(int p,int L,int R,db v,int x){
if(L==p&&R==p){tr[x]=max(tr[x],v);return;}//注意要不断取max! 而不是赋值!
int mid=(L+R)>>1;
if(p<=mid)modify(p,L,mid,v,lc);
else modify(p,mid+1,R,v,rc);
tr[x]=max(tr[lc],tr[rc]);
}
db query(int l,int r,int L,int R,int x){
if(l==L&&r==R)return tr[x];
int mid=(L+R)>>1;
if(r<=mid)return query(l,r,L,mid,lc);
else if(l>mid)return query(l,r,mid+1,R,rc);
else return max(query(l,mid,L,mid,lc),query(mid+1,r,mid+1,R,rc));
}
}T;
void dfs1(int u,int f){
len[u]=0;fa[u]=f;
for(int i=h[u];i;i=nxt[i]){
if(to[i]==f)continue;
dfs1(to[i],u);
if(len[son[u]]<len[to[i]])son[u]=to[i],len[u]=len[son[u]]+1,sonw[u]=w[i];
}
}
db ans;
void dfs2(int u){
dfn[u]=++tim;
if(!son[u])return;
dfs2(son[u]);
for(int i=h[u];i;i=nxt[i]){
if(to[i]==fa[u]||to[i]==son[u])continue;
dfs2(to[i]);
}
}
db dis[N],now[N];
void solve(int u,db x){
T.modify(dfn[u],1,n,dis[u],1);
if(son[u]){
dis[son[u]]=dis[u]+sonw[u]-x;
solve(son[u],x);
}
for(int i=h[u];i;i=nxt[i]){
if(to[i]==son[u]||to[i]==fa[u])continue;
int d=to[i];
dis[d]=dis[u]+w[i]-x;
solve(d,x);
rep(j,0,len[d]){
now[j]=T.query(dfn[d]+j,dfn[d]+j,1,n,1);
if(j+1<=upp&&j+len[u]+1>=lwr){
ans=max(ans,now[j]+T.query(dfn[u]+max(0,lwr-j-1),dfn[u]+min(upp-j-1,len[u]),1,n,1)-2*dis[u]);
}
}
rep(j,0,len[d]){
T.modify(dfn[u]+j+1,1,n,now[j],1);
}
}
if(len[u]>=lwr)
ans=max(ans,T.query(dfn[u]+lwr,dfn[u]+min(len[u],upp),1,n,1)-dis[u]);
//
}
bool check(db x){
memset(T.tr,0xc2,sizeof(T.tr));
ans=-1e7;
solve(1,x);
return (ans>=0);
}
int main(){
scanf("%d",&n);
scanf("%d%d",&lwr,&upp);
int a,b,c,mx=0;
rep(i,1,n-1)scanf("%d%d%d",&a,&b,&c),tu(a,b,c),tu(b,a,c),mx=max(mx,c);
len[0]=-1;
dfs1(1,0),dfs2(1);
db l=0,r=mx,mid;
while(l+eps<r){
mid=(l+r)/2.0;
if(check(mid))l=mid;
else r=mid;
}
if(check(r))printf("%.3lf\n",r);
else printf("%.3lf\n",l);
return 0;
}