树链剖分
树链剖分
用途
- 将任意一条树上路径转化为不超过 l o g n logn logn段区间
- 将一棵树转化成一个序列
这样我们求树上路径和就和以用线段树求
复杂度 O ( l o g 2 n ) O(log^2n) O(log2n)
求法&概念
求出每一棵子树的节点个数
其中某个点的,所在子树的节点个数最多的那个儿子,被称为该点的重儿子
其余的被称为轻儿子
子树大小相等时,重儿子可以任选
我们预处理出每个子树的size,并求出来所有重儿子
我们引入几个概念:
- 重边:连接一个重儿子和它的父节点的边
- 轻边: 其余的边
- 重链:极大地由重边构成的路径
值得注意的是,每个点都应该在一条重链里
一个点在哪条重链里呢?
- 重儿子: 父节点所在的重链里
- 轻儿子:以这个点开头,往下走的重链里
核心结论:树中任意一条路径,均可拆分成 O ( l o g n ) O(logn) O(logn)条重链(不一定完整)
我们如何将一棵树变成一条序列呢?使用特殊的DFS序
在DFS中,我们优先遍历该点的重儿子
即可保证每条重链上所有点的编号是连续的
这样我们就把一条路径转化成 l o g n logn logn段连续区间
一些实现细节:
然后考虑如何将一条路径转化成重链
这个做法有点类似LCA的求法
对于两个点 x , y x,y x,y,我们分别找出他们所在重链的顶点
然后对于深度较大的顶点,找到它父亲所在的重链,接着进行类似的过程
直到所在重链的顶点相同,然后我们把重链重合的那一段抠出来,就得到了路径所经过的重链
例题
我们将对树上路径上的操作转化为连续区间的操作
然后使用线段树进行区间修改,区间查询
/*************************************************************************
> File Name: p3384[模板]轻重链剖分.cpp
> Author: Typedef
> Mail: 1815979752@qq.com
> Created Time: 2021/2/26 12:41:05
> TAG: 树链剖分,线段树,模板
************************************************************************/
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100010,M=N*2;
int n,m;
int cnt=0;
int w[N],h[N],e[M],ne[M],idx;
int id[N],nw[N];//nw:dfs序中第某个点的权值是多少,id表示原来树中某个点dfs序是多少
int dep[N],sz[N],top[N],fa[N],son[N];
struct Tree{
int l,r;
ll add,sum;
}tr[N*4];
void add(int a,int b){
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs1(int u,int father,int depth){
dep[u]=depth,fa[u]=father,sz[u]=1;
for(int i=h[u];~i;i=ne[i]){
int j=e[i];
if(j==father) continue;
dfs1(j,u,depth+1);
sz[u]+=sz[j];
if(sz[son[u]]<sz[j]) son[u]=j;
}
}
void dfs2(int u,int t){
id[u]=++cnt,nw[cnt]=w[u],top[u]=t;
if(!son[u]) return;
dfs2(son[u],t);
for(int i=h[u];~i;i=ne[i]){
int j=e[i];
if(j==fa[u]||j==son[u]) continue;
dfs2(j,j);
}
}
void pushup(int u){
tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
}
void pushdown(int u){
auto &root=tr[u],&left=tr[u<<1],&right=tr[u<<1|1];
if(root.add){
left.add+=root.add,left.sum+=root.add*(left.r-left.l+1);
right.add+=root.add,right.sum+=root.add*(right.r-right.l+1);
root.add=0;
}
}
void build(int u,int l,int r){
tr[u]={l,r,0,nw[r]};
if(l==r) return;
int mid=l+r>>1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
pushup(u);
}
void update(int u,int l,int r,int k){
if(l<=tr[u].l&&r>=tr[u].r){
tr[u].add+=k;
tr[u].sum+=k*(tr[u].r-tr[u].l+1);
return;
}
pushdown(u);
int mid=tr[u].l+tr[u].r>>1;
if(l<=mid) update(u<<1,l,r,k);
if(r>mid) update(u<<1|1,l,r,k);
pushup(u);
}
ll query(int u,int l,int r){
if(l<=tr[u].l&&r>=tr[u].r) return tr[u].sum;
pushdown(u);
int mid=tr[u].l+tr[u].r>>1;
ll res=0;
if(l<=mid) res+=query(u<<1,l,r);
if(r>mid) res+=query(u<<1|1,l,r);
return res;
}
void update_path(int u,int v,int k){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
update(1,id[top[u]],id[u],k);//重链连续区间
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
update(1,id[v],id[u],k);
}
ll query_path(int u,int v){
ll res=0;
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
res+=query(1,id[top[u]],id[u]);
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
res+=query(1,id[v],id[u]);
return res;
}
void update_tree(int u,int k){
update(1,id[u],id[u]+sz[u]-1,k);//根节点的id是最小的,因为最先遍历,同时遍历完子树后才会回溯
}
ll query_tree(int u){
return query(1,id[u],id[u]+sz[u]-1);
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
memset(h,-1,sizeof(h));
for(int i=0;i<n-1;i++){
int a,b;
scanf("%d%d",&a,&b);
add(a,b),add(b,a);
}
dfs1(1,-1,1);
dfs2(1,1);
build(1,1,n);
scanf("%d",&m);
while(m--){
int t,u,v,k;
scanf("%d%d",&t,&u);
if(t==1){
scanf("%d%d",&v,&k);
update_path(u,v,k);
}
else if(t==2){
scanf("%d",&k);
update_tree(u,k);
}
else if(t==3){
scanf("%d",&v);
printf("%lld\n",query_path(u,v));
}
else printf("%lld\n",query_tree(u));
}
system("pause");
return 0;
}
我们以依赖关系建立一棵树
对于安装操作,将根到 x x x的路径上的所有但变成 1 1 1,输出多少 0 0 0变成 1 1 1
对于卸载操作,将以 x x x为根的子树全部变成 0 0 0,输出多少 1 1 1变成 0 0 0
通过查询区间和变化完成输出
对于区间赋值,我们用flag=-1
表示无操作,flag=0
表示区间赋值为
0
0
0,
f
l
a
g
=
1
flag=1
flag=1表示区间赋值为
1
1
1
/*************************************************************************
> File Name: p2146[NOI2015]软件包管理器.cpp
> Author: Typedef
> Mail: 1815979752@qq.com
> Created Time: 2021/2/27 15:24:15
> TAG:
************************************************************************/
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+7;
int n,m;
int cnt=0;
int id[N];
int h[N],e[N],ne[N],idx;
int dep[N],sz[N],top[N],fa[N],son[N];
struct Tree{
int l,r,flag,sum;
}tr[N*4];
void add(int a,int b){
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs1(int u,int depth){
dep[u]=depth,sz[u]=1;
for(int i=h[u];~i;i=ne[i]){
int j=e[i];
dfs1(j,depth+1);
sz[u]+=sz[j];
if(sz[son[u]]<sz[j]) son[u]=j;
}
}
void dfs2(int u,int t){
id[u]=++cnt,top[u]=t;
if(!son[u]) return;
dfs2(son[u],t);
for(int i=h[u];~i;i=ne[i]){
int j=e[i];
if(j==son[u]) continue;
dfs2(j,j);
}
}
void pushup(int u){
tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
}
void pushdown(int u){
auto &root=tr[u],&left=tr[u<<1],&right=tr[u<<1|1];
if(root.flag!=-1){
left.sum=root.flag*(left.r-left.l+1);
right.sum=root.flag*(right.r-right.l+1);
left.flag=right.flag=root.flag;
root.flag=-1;
}
}
void build(int u,int l,int r){
tr[u]={l,r,-1,0};
if(l==r) return;
int mid=l+r>>1;
build(u<<1,l,mid);
build(u<<1|1,mid+1,r);
pushup(u);
}
void update(int u,int l,int r,int k){
if(tr[u].l>=l&&r>=tr[u].r){
tr[u].flag=k;
tr[u].sum=k*(tr[u].r-tr[u].l+1);
return;
}
int mid=tr[u].l+tr[u].r>>1;
pushdown(u);
if(l<=mid) update(u<<1,l,r,k);
if(r>mid) update(u<<1|1,l,r,k);
pushup(u);
}
void update_path(int u,int v,int k){
while(top[u]!=top[v]){
if(dep[u]<dep[top[v]]) swap(u,v);
update(1,id[top[u]],id[u],k);
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
update(1,id[v],id[u],k);
}
void update_tree(int u,int k){
update(1,id[u],id[u]+sz[u]-1,k);
}
int main(){
scanf("%d",&n);
memset(h,-1,sizeof(h));
for(int i=2;i<=n;i++){
int p;
scanf("%d",&p);
p++;
add(p,i);
fa[i]=p;
}
dfs1(1,1);
dfs2(1,1);
build(1,1,n);
scanf("%d",&m);
char op[20];
int x;
while(m--){
scanf("%s%d",op,&x);
x++;
if(!strcmp(op,"install")){
int t=tr[1].sum;
update_path(1,x,1);
printf("%d\n",tr[1].sum-t);
}
else{
int t=tr[1].sum;
update_tree(x,0);
printf("%d\n",t-tr[1].sum);
}
}
system("pause");
exit(0);
}