长链剖分学习笔记
长链剖分
长链剖分用于优化一些特殊的dp,可以将某些\(O(n)\)的时间复杂度降为均摊\(O(1)\)。
感觉这玩意儿大部分东西都和树链剖分挺像,理解的时候可以照着轻重链剖分那种去理解
定义
长链
和重链差不多,就是从某一个节点走到它子树中最深的节点所经过的路径
重儿子
某个节点在长链上的儿子就是它的重儿子(很好奇为什么不叫长儿子)。
顶点
某条长链中深度最小的点就是该长链的顶点
不难看出长链剖分和树链剖分其实很像,树链剖分中重儿子所在子树具有最大的size,而长链剖分中重儿子所在子树具有最大的dep
性质
1.所有链长总和为\(O(n)\)
2.任意一个节点\(x\)的\(k\)级祖先\(y\)所在长链长度一定大于等于\(k\)
正确性的话,自己yy一下好了
应用
因为这东西的基础基本和树剖没什么区别所以不多讲了,还是从应用里还说明它的一下特殊性好了
O(1)在线查询某一个点的\(k\)级祖先
我们设\(len[u]\)为\(u\)所在长链的长度,对于每一个长链的顶点,我们维护它的1到\(len[u]\)级儿子以及1到\(len[u]\)级祖先
同时预处理找祖先的倍增数组,以及\(1\)到\(n\)的每一个数字的二进制最高位即\(highbit\)
那么对于每一个询问\((u,k)\),我们设\(r=highbit(k)\),那么我们用预处理的倍增数组让\(u\)跳到它的\(r\)级祖先\(v\)处
因为\(k-r<r\),那么\(v\)的长链的长度\(\geq r>k-r\),那么\(v\)所在的长链预处理的表一定已经包含了\(u\)的\(k\)级祖先
时间复杂度为\(O(nlogn+m)\),预处理\(O(nlogn)\),每一次回答\(O(1)\)
//minamoto
#include<iostream>
#include<cstdio>
#include<vector>
using namespace std;
#define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1<<21],*p1=buf,*p2=buf;
int read(){
#define num ch-'0'
char ch;bool flag=0;int res;
while(!isdigit(ch=getc()))
(ch=='-')&&(flag=true);
for(res=num;isdigit(ch=getc());res=res*10+num);
(flag)&&(res=-res);
#undef num
return res;
}
char sr[1<<21],z[20];int C=-1,Z;
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
void print(int x){
if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x;
while(z[++Z]=x%10+48,x/=10);
while(sr[++C]=z[Z],--Z);sr[++C]='\n';
}
const int N=3e5+5;
int head[N],Next[N<<1],ver[N<<1],tot;
inline void add(int u,int v){
ver[++tot]=v,Next[tot]=head[u],head[u]=tot;
}
int n,md[N],dep[N],fa[N][21],son[N],top[N],len[N],B[N];
vector<int> U[N],D[N];
void dfs(int u,int f){
md[u]=dep[u]=dep[f]+1,fa[u][0]=f;
for(int i=1;i<20;++i)
if(fa[u][i-1]) fa[u][i]=fa[fa[u][i-1]][i-1];
else break;
for(int i=head[u];i;i=Next[i]){
int v=ver[i];
if(v!=f){
dfs(v,u);
if(md[son[u]]<md[v]) son[u]=v,md[u]=md[v];
}
}
}
void dfs2(int u,int t){
top[u]=t,len[u]=md[u]-dep[t]+1;
if(son[u]){
dfs2(son[u],t);
for(int i=head[u];i;i=Next[i])
if(!top[ver[i]]) dfs2(ver[i],ver[i]);
}
}
void init(){
int now=0;
for(int i=1;i<=n;++i){
if(!(i&(1<<now))) ++now;
B[i]=now;
}
for(int i=1;i<=n;++i)
if(i==top[i]){
for(int j=1,u=i;j<=len[i]&&u;++j) u=fa[u][0],U[i].push_back(u);
for(int j=1,u=i;j<=len[i]&&u;++j) u=son[u],D[i].push_back(u);
}
}
int query(int u,int k){
if(k>dep[u]) return 0;if(k==0) return u;
u=fa[u][B[k]],k^=1<<B[k];
if(k==0) return u;
if(dep[u]-dep[top[u]]==k) return top[u];
if(dep[u]-dep[top[u]]<k) return U[top[u]][k-dep[u]+dep[top[u]]-1];
else return D[top[u]][dep[u]-dep[top[u]]-k-1];
}
int main(){
// freopen("testdata.in","r",stdin);
n=read();
for(int i=1;i<n;++i){
int u=read(),v=read();
add(u,v),add(v,u);
}
dfs(1,0),dfs2(1,1),init();
int lastans=0,q=read();
while(q--){
int u=read()^lastans,v=read()^lastans;
printf("%d\n",lastans=query(u,v));
}
return Ot(),0;
}
O(n)合并以深度为下标的信息
这个比较类似于dsu on tree了,对于一些和深度有关的信息,我们可以直接继承它重儿子的信息,对于轻儿子的暴力统计。比方说\(f_{u,i}\)表示到\(u\)的子树中到它的距离为\(i\)的点的个数,\(v\)为\(u\)的重儿子,那么可以直接\(f_{u,i}=f_{v,i-1}\),然后轻儿子的暴力统计
时间复杂度证明的话,因为每个点只会暴力统计它的轻儿子,所以每一个点也只会在从它到根节点的第一条轻边上被统计一遍,于是总的复杂度是\(O(n)\)
对于每个点,只需要开正比于它所在长链长度的空间,总的空间复杂度也是\(O(n)\)
具体实现的话,可以用指针来维护,每一次合并的之后只要把指针给移位即可
给你一棵树,定义\(d_{x,i}\)表示\(x\)子树内和\(x\)距离为\(i\)的节点数,对每个\(x\)求使\(d_{x,i}\)最大的\(i\),如有多个输出最小的。
首先我们得求出对于每个点,它子树中距离它不同距离的点有多少,很显然\(f_{u,i}=\sum f_{v,i-1}\)。然而时空都得炸。于是让每个点继承重儿子的信息,对轻儿子的信息暴力统计,对于每一条长链,发现开的数组最大是这条长链的长度,其他都没有用。于是总的空间复杂度即为长链的长度之和,为\(O(n)\)
//minamoto
#include<bits/stdc++.h>
using namespace std;
#define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1<<21],*p1=buf,*p2=buf;
int read(){
int res,f=1;char ch;
while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
return res*f;
}
char sr[1<<21],z[20];int C=-1,Z=0;
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
void print(int x){
if(C>1<<20)Ot();if(x<0)sr[++C]='-',x=-x;
while(z[++Z]=x%10+48,x/=10);
while(sr[++C]=z[Z],--Z);sr[++C]='\n';
}
const int N=1e6+5;
int head[N],Next[N<<1],ver[N<<1],tot;
inline void add(int u,int v){ver[++tot]=v,Next[tot]=head[u],head[u]=tot;}
int len[N],son[N],tmp[N],*f[N],*id=tmp,ans[N],n;
void dfs(int u,int fa){
for(int i=head[u];i;i=Next[i])if(ver[i]!=fa){
dfs(ver[i],u);
if(len[ver[i]]>len[son[u]])son[u]=ver[i];
}
len[u]=len[son[u]]+1;
}
void dp(int u,int fa){
f[u][0]=1;if(son[u])f[son[u]]=f[u]+1,dp(son[u],u),ans[u]=ans[son[u]]+1;
for(int i=head[u];i;i=Next[i]){
int v=ver[i];if(v==fa||v==son[u])continue;
f[v]=id,id+=len[v],dp(v,u);
for(int j=1;j<=len[v];++j){
f[u][j]+=f[v][j-1];
if((j<ans[u]&&f[u][j]>=f[u][ans[u]])||(j>ans[u]&&f[u][j]>f[u][ans[u]]))
ans[u]=j;
}
}
if(f[u][ans[u]]==1)ans[u]=0;
}
int main(){
// freopen("testdata.in","r",stdin);
n=read();
for(int i=1,u,v;i<n;++i)u=read(),v=read(),add(u,v),add(v,u);
dfs(1,0),f[1]=id,id+=len[1],dp(1,0);
for(int i=1;i<=n;++i)print(ans[i]);
return Ot(),0;
}
以及一些例题
参考文章