长链剖分小结
长链剖分是一种类似\(\rm{dsu\ on\ tree}\)的一种算法,写法类似于普通的树链剖分(重链剖分),只是将\(\rm{siz}\)最大改为了\(\rm{dep}\)最大.可以优化一些与子树深度相关的问题的时间.
性质
1、所有链的长度和为\(O(n)\)级别的
所有的点均只会在一条长链里,所以都只会被计算一次,所以是\(O(n)\)级别的
2、父亲所在的长链长度不会小于其儿子所在的长链
如果上述不成立的话,那么父亲点可以选择该儿子使得长链更长,与原来相矛盾.
这个性质有个推论:对于任何一个点,其\(k\)次祖先所在的长链必然大于等于\(k\),证明类似.
3、从某个点出发向上跳,切换长链的次数是\(O(\sqrt n)\)级别的
根据性质\(2\),每次跳的长链长度一定不会小于上一次的,即最坏情况索所跳的长链长度为\(1,2,3,\cdots\),也就是跳了\(O(\sqrt n)\)次
实现及例题
长链剖分的实现中的第一个dfs类似于重链剖分.
void dfs1(int u,int fu)
{
for (int i=head[u];i;i=sq[i].nxt)
{
int v=sq[i].to;
if (v==fu) continue;
dfs1(v,u);
if (len[son[u]]<len[v]) son[u]=v;
}
len[u]=len[son[u]]+1;
}
其中\(len\)记录的是\(u\)所在的长链在\(u\)子树中的长度.
1、求\(k\)级祖先
普通的想法是倍增,可以做到\(O(n\log n)\)预处理\(O(\log n)\)询问,看起来很优秀,但是可以做到更好.
来考虑一下长链剖分,对每条长链及其长度为\(m\),预处理处其链顶向上的\(k\)个祖先和向下的\(k\)个重儿子,由性质1知这是\(O(n)\)的.之后对于每个询问我们先跳到这个点的\(x\)级祖先上,保证这个祖先所在长链的链长\(>k-x\),之后再跳到这条链的链顶上,根据预处理的结果求出答案.
如何保证链长\(>k-x\)呢?发现对每个\(k\),记其二进制表示下最高位为第\(h_k\)位,那么令\(x=2^{h_k}\)即可,这样的话和倍增一样的预处理出那个数组即可.总的是\(O(n\log n)\)预处理\(O(1)\)查询.
例题:luogu5903
#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<math.h>
#include<queue>
#include<set>
#include<map>
using namespace std;
typedef long long ll;
typedef long double db;
typedef pair<int,int> pii;
const int N=100000;
const db pi=acos(-1.0);
#define lowbit(x) (x)&(-x)
#define sqr(x) (x)*(x)
#define rep(i,a,b) for (register int i=a;i<=b;i++)
#define per(i,a,b) for (register int i=a;i>=b;i--)
#define go(u,i) for (register int i=head[u];i;i=sq[i].nxt)
#define fir first
#define sec second
#define mp make_pair
#define pb push_back
#define maxd 998244353
#define eps 1e-8
struct node{int to,nxt;}sq[1001000];
int all=0,head[500500];
int n,dep[500500],mx[500500],son[500500],fa[500500][20],tp[500500],rt,hbit[500500],q;
vector<int> U[500500],D[500500];
inline int read()
{
int x=0,f=1;char ch=getchar();
while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
return x*f;
}
#define ui unsigned int
ui s;
inline ui get(ui x) {
x ^= x << 13;
x ^= x >> 17;
x ^= x << 5;
return s = x;
}
void add(int u,int v)
{
all++;sq[all].to=v;sq[all].nxt=head[u];head[u]=all;
}
void dfs1(int u,int fu)
{
dep[u]=dep[fu]+1;mx[u]=dep[u];fa[u][0]=fu;
rep(i,1,19) fa[u][i]=fa[fa[u][i-1]][i-1];
go(u,i)
{
int v=sq[i].to;
if (v==fu) continue;
dfs1(v,u);
if (mx[v]>mx[u]) {son[u]=v;mx[u]=mx[v];}
}
}
void dfs2(int u,int tpu)
{
tp[u]=tpu;
if (u==tpu)
{
int now=u;
rep(i,0,mx[u]-dep[u])
{
D[u].pb(now);
now=son[now];
}
now=u;
rep(i,0,mx[u]-dep[u])
{
U[u].pb(now);
now=fa[now][0];
}
}
if (son[u]) dfs2(son[u],tpu);
go(u,i)
{
int v=sq[i].to;
if ((v==fa[u][0]) || (v==son[u])) continue;
dfs2(v,v);
}
}
int query(int x,int k)
{
if (!k) return x;
x=fa[x][hbit[k]];k-=(1<<hbit[k]);
//cout << "half " << x << " " << k << endl;
k-=(dep[x]-dep[tp[x]]);x=tp[x];
//cout << "now " << x << " " << k << endl;
if (k>=0) return U[x][k];else return D[x][-k];
}
int main()
{
n=read();q=read();s=read();
rep(i,1,n)
{
int fa=read();
add(fa,i);add(i,fa);
if (!fa) rt=i;
}
rep(i,2,n) hbit[i]=hbit[i>>1]+1;
dfs1(rt,0);dfs2(rt,rt);
int ans=0;ll fin=0;
rep(i,1,q)
{
int x=(get(s)^ans)%n+1,k=(get(s)^ans)%dep[x];
ans=query(x,k);
fin^=(1ll*i*ans);
}
printf("%lld",fin);
return 0;
}
2、优化某些dp
有些dp的状态形如\(f_{u,i}\),其中\(i\)这一维只与深度有关。对于这样的dp我们可以使用长链剖分进行优化。具体的,我们先做\(u\)的重儿子,之后再将所有的轻儿子的答案合并到这上面去。如果我们合并的时候可以做到\(O(len)\)合并,那么总的时间复杂度就是\(O(\sum len)\),也就是\(O(n)\)
先把暴力dp的式子写起来:\(f_{u,i}\)表示在\(u\)的子树中距离\(u\)等于\(i\)的点的个数,那么有
第二维的信息只和深度有关,于是可以用长链剖分来优化dp
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<math.h>
#include<queue>
#include<set>
using namespace std;
typedef long long ll;
typedef long double db;
typedef pair<int,int> pii;
const int N=100000;
const db pi=acos(-1.0);
#define lowbit(x) (x)&(-x)
#define sqr(x) (x)*(x)
#define rep(i,a,b) for (register int i=a;i<=b;i++)
#define per(i,a,b) for (register int i=a;i>=b;i--)
#define go(u,i) for (register int i=head[u];i;i=sq[i].nxt)
#define fir first
#define sec second
#define mp make_pair
#define pb push_back
#define maxd 998244353
#define eps 1e-8
struct node{int to,nxt;}sq[2002000];
int all=0,head[1001000];
int n,son[1001000],ans[1001000],*f[1001000],tmp[1001000],*id=tmp,len[1001000];
inline int read()
{
int x=0,f=1;char ch=getchar();
while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
return x*f;
}
void add(int u,int v)
{
all++;sq[all].to=v;sq[all].nxt=head[u];head[u]=all;
}
void dfs1(int u,int fu)
{
go(u,i)
{
int v=sq[i].to;
if (v==fu) continue;
dfs1(v,u);
if (len[son[u]]<len[v]) son[u]=v;
}
len[u]=len[son[u]]+1;
}
void dfs2(int u,int fu)
{
f[u][0]=1;
if (son[u])
{
f[son[u]]=f[u]+1;
dfs2(son[u],u);
ans[u]=ans[son[u]]+1;
}
go(u,i)
{
int v=sq[i].to;
if ((v==fu) || (v==son[u])) continue;
f[v]=id;id+=len[v];dfs2(v,u);
rep(j,1,len[v])
{
f[u][j]+=f[v][j-1];
if (((j<ans[u]) && (f[u][j]>=f[u][ans[u]])) || ((j>ans[u]) && (f[u][j]>f[u][ans[u]])))
ans[u]=j;
}
}
if (f[u][ans[u]]==1) ans[u]=0;
}
int main()
{
n=read();
rep(i,1,n-1)
{
int u=read(),v=read();
add(u,v);add(v,u);
}
dfs1(1,0);
f[1]=id;id+=len[1];
dfs2(1,0);
rep(i,1,n) printf("%d\n",ans[i]);
return 0;
}