【学习笔记】(11) 树链剖分——再战三百回
树链剖分,顾名思义,就是将树分割成若干条链的形式,以维护树上路径的信息。
重链剖分
这里给出一些定义:
- 重儿子:表示其子节点中子树最大的子结点
- 轻儿子:不是重儿子的子节点
- 重边:父节点到重儿子的边
- 轻边:父节点到轻儿子的边
- 重链:若干条首尾衔接的重边构成的链
这里引用一下 OI Wiki 的图
实现
树剖的实现依靠两个 dfs
dfs1(t, f)
- \(d[x]\) 表示 \(x\) 节点的深度
- \(fa[x]\) 表示 \(x\) 节点的父亲
- \(siz[x]\) 表示 \(x\) 节点的子树大小
- \(son[x]\) 表示 \(x\) 节点的重儿子
void dfs1(int x,int f){
d[x]=d[f]+1,fa[x]=f,siz[x]=1;
int maxson=-1;
for(int i=Head[x];i;i=Next[i]){
int y=to[i];if(y==f) continue;
dfs1(y,x);siz[x]+=siz[y];
if(siz[y]>maxson) son[x]=y,maxson=siz[y];
}
}
dfs2(x, topf)
- \(id[x]\) 表示 \(x\) 节点 新的编号(即图中的 DFN 序)
- \(val[x]\) 表示编号 \(x\) 的权值
- \(top[x]\) 表示 \(x\) 节点所在重链的顶部节点(深度最小)
顺序:先处理重儿子再处理轻儿子,理由后面说
void dfs2(int x,int topf){
id[x]=++t,val[t]=a[x],top[x]=topf;
if(!son[x]) return ;
dfs2(son[x],topf);
for(int i=Head[x];i;i=Next[i]){
int y=to[i];if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
一些性质
-
性质1:如果 \((u,v)\) 为轻边,则 \(size(v) \le size(u) / 2\)。
- 证明:反证法。如果 \(size(v) > size(u)/2\),则 \(size(v)\) 必然比其他儿子的 \(size\) 要大,那么 \((u,v)\) 必然为重边,与 \((u,v)\) 为轻边相矛盾。
-
性质2:从根到某一个点 \(V\) 的路径上的轻边个数不大于 \(\log n\)
- 证明:当\(v\)为叶子节点时轻边个数最多。由性质 1 可得,每经过一条轻边,子树的节点数至少比原来少一半,所以至多经过 \(log n\)条轻边就到达叶子节点了。
-
性质3:对于树上的每个点到根上的路径都有不超过 \(\log n\) 条轻边 和 \(\log n\) 条重链。
- 证明:显然每条重链的起点和重点都是由轻边连接的,而由性质 2 可知,每个点到根节点的轻边个数最多为 \(\log n\),所以重链数量也不超过 \(\log n\)。
我们可以发现,一棵子树内的 DFN 序 是连续的,重链上的 DFN 序也是连续的,所以可以用线段树来维护区间信息。
常见应用
路径上维护
例如求 \((u,v)\) 路径上的权值和。
我们可以考虑将 \((u,v)\) 拆分成若干条重链(实际上这个过程就是在找LCA)。
假定 \(top[u]\) 与 \(top[v]\) 不同,那么 LCA 肯定不可能在 \(top\) 深度较大的那条重链上,所以我们先处理 \(top\) 深度较大的点。
假设是 \(u\), 则我们可以直接跳到 \(fa[top[u]]\) 处,且跳过的这段 DFN 序是连续的,所以这段在线段树中就是区间 \([id[top[x]],id[x] ]\)。当 \(top[u]\) 与 \(top[v]\) 相等,说明它们走到了同一条重链上,这时他们之间的路径也是连续的一段,且 \(u,v\) 中深度小的那个点就是 \(LCA\)。
int qSum(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
ans+=QuerySum(1,1,n,id[top[x]],id[x]);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
ans+=QuerySum(1,1,n,id[x],id[y]);
return ans;
}
子树维护
例如将子树 \(x\) 内的点都加上 \(k\)
由于子树内的 DFN 序是连续的,所以可以将子树用 DFN 序转成一段区间,用线段树维护即可。
void UpDateSon(int x,int k){
UpDate(1,1,n,id[x],id[x]+siz[x]-1,k);
}
求最近公共祖先
与路径维护操作相似。
不断向上跳重链,当跳到同一条重链上时,深度较小的结点即为 LCA。
向上跳重链时需要先跳所在重链顶端深度较大的那个。
int LCA(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
return x;
}
P2590 [ZJOI2008]树的统计
根据题面以及以上的性质,你的线段树需要维护三种操作:
- 单点修改;
- 区间查询最大值;
- 区间查询和。
路径维护,树链剖分基操(码风可能有点丑,毕竟是之前打的)
点击查看代码
#include<bits/stdc++.h>
#define N 30005
#define ls u<<1
#define rs u<<1|1
using namespace std;
void read(int &o) {
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
o=x*f;
}
void write(int x){
if(x<0){putchar('-');x*=-1;}
if(x>9) write(x/10);
putchar(x%10+'0');
}
int n,q;
int tot,Head[N],to[N<<1],Next[N<<1];
int val[N],d[N],fa[N],top[N],size[N],son[N];
int V[N],id[N],cnt;
int sum[N<<4],maxn[N<<4];
void PushUp(int u){
sum[u]=sum[ls]+sum[rs];
maxn[u]=max(maxn[ls],maxn[rs]);
}
void add(int u,int v){
to[++tot]=v,Next[tot]=Head[u],Head[u]=tot;
}
void Build(int u,int l,int r){
if(l==r){
sum[u]=maxn[u]=V[l];
return ;
}
int mid=(l+r)>>1;
Build(ls,l,mid);
Build(rs,mid+1,r);
PushUp(u);
}
void dfs1(int x,int f,int dep){
d[x]=dep,size[x]=1,fa[x]=f;
int maxson=-1;
for(int i=Head[x];i;i=Next[i]){
int y=to[i];
if(y==f) continue;
dfs1(y,x,dep+1);
size[x]+=size[y];
if(size[y]>maxson) maxson=size[y],son[x]=y;
}
}
void dfs2(int x,int topf){
id[x]=++cnt,V[cnt]=val[x],top[x]=topf;
if(!son[x]) return ;
dfs2(son[x],topf);
for(int i=Head[x];i;i=Next[i]){
int y=to[i];
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
void UpDate(int u,int l,int r,int pos,int v){
if(l==r){
sum[u]=maxn[u]=v;
return ;
}
int mid=(l+r)>>1;
if(pos<=mid) UpDate(ls,l,mid,pos,v);
else UpDate(rs,mid+1,r,pos,v);
PushUp(u);
}
int QueryMax(int u,int l,int r,int L,int R){
if(L<=l&&r<=R){
return maxn[u];
}
int mid=(l+r)>>1,ans=-0x3f3f3f3f;
if(L<=mid) ans=max(ans,QueryMax(ls,l,mid,L,R));
if(R>mid) ans=max(ans,QueryMax(rs,mid+1,r,L,R));
return ans;
}
int QuerySum(int u,int l,int r,int L,int R){
if(L<=l&&r<=R) return sum[u];
int mid=(l+r)>>1,ans=0;
if(L<=mid) ans+=QuerySum(ls,l,mid,L,R);
if(R>mid) ans+=QuerySum(rs,mid+1,r,L,R);
return ans;
}
int qSum(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
ans+=QuerySum(1,1,n,id[top[x]],id[x]);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
ans+=QuerySum(1,1,n,id[x],id[y]);
return ans;
}
int LCA(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
return x;
}
int qMax(int x,int y){
int ans=-0x3f3f3f3f;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y );
ans=max(ans,QueryMax(1,1,n,id[top[x]],id[x]));
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
ans=max(ans,QueryMax(1,1,n,id[x],id[y]));
return ans;
}
int main(){
read(n);
for(int i=1;i<n;i++){
int u,v;
read(u),read(v);
add(u,v),add(v,u);
}
for(int i=1;i<=n;i++) read(val[i]);
dfs1(1,0,1);
dfs2(1,1);
Build(1,1,n);
read(q);
while(q--){
char opt[10];
int u,v;
scanf("%s",opt);
scanf("%d%d",&u,&v);
if(opt[1]=='M'){
write(qMax(u,v)),printf("\n");
}else if(opt[1]=='S'){
write(qSum(u,v)),printf("\n");
}else{
UpDate(1,1,n,id[u], v);
}
}
return 0;
}
P3384 【模板】重链剖分/树链剖分
模板题,操作上面都提到过。
点击查看代码
#include<bits/stdc++.h>
#define N 30005
#define ls u<<1
#define rs u<<1|1
using namespace std;
void read(int &o) {
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
o=x*f;
}
void write(int x){
if(x<0){putchar('-');x*=-1;}
if(x>9) write(x/10);
putchar(x%10+'0');
}
int n,q;
int tot,Head[N],to[N<<1],Next[N<<1];
int val[N],d[N],fa[N],top[N],size[N],son[N];
int V[N],id[N],cnt;
int sum[N<<4],maxn[N<<4];
void PushUp(int u){
sum[u]=sum[ls]+sum[rs];
maxn[u]=max(maxn[ls],maxn[rs]);
}
void add(int u,int v){
to[++tot]=v,Next[tot]=Head[u],Head[u]=tot;
}
void Build(int u,int l,int r){
if(l==r){
sum[u]=maxn[u]=V[l];
return ;
}
int mid=(l+r)>>1;
Build(ls,l,mid);
Build(rs,mid+1,r);
PushUp(u);
}
void dfs1(int x,int f,int dep){
d[x]=dep,size[x]=1,fa[x]=f;
int maxson=-1;
for(int i=Head[x];i;i=Next[i]){
int y=to[i];
if(y==f) continue;
dfs1(y,x,dep+1);
size[x]+=size[y];
if(size[y]>maxson) maxson=size[y],son[x]=y;
}
}
void dfs2(int x,int topf){
id[x]=++cnt,V[cnt]=val[x],top[x]=topf;
if(!son[x]) return ;
dfs2(son[x],topf);
for(int i=Head[x];i;i=Next[i]){
int y=to[i];
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
void UpDate(int u,int l,int r,int pos,int v){
if(l==r){
sum[u]=maxn[u]=v;
return ;
}
int mid=(l+r)>>1;
if(pos<=mid) UpDate(ls,l,mid,pos,v);
else UpDate(rs,mid+1,r,pos,v);
PushUp(u);
}
int QueryMax(int u,int l,int r,int L,int R){
if(L<=l&&r<=R){
return maxn[u];
}
int mid=(l+r)>>1,ans=-0x3f3f3f3f;
if(L<=mid) ans=max(ans,QueryMax(ls,l,mid,L,R));
if(R>mid) ans=max(ans,QueryMax(rs,mid+1,r,L,R));
return ans;
}
int QuerySum(int u,int l,int r,int L,int R){
if(L<=l&&r<=R) return sum[u];
int mid=(l+r)>>1,ans=0;
if(L<=mid) ans+=QuerySum(ls,l,mid,L,R);
if(R>mid) ans+=QuerySum(rs,mid+1,r,L,R);
return ans;
}
int qSum(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
ans+=QuerySum(1,1,n,id[top[x]],id[x]);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
ans+=QuerySum(1,1,n,id[x],id[y]);
return ans;
}
int LCA(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
return x;
}
int qMax(int x,int y){
int ans=-0x3f3f3f3f;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y );
ans=max(ans,QueryMax(1,1,n,id[top[x]],id[x]));
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
ans=max(ans,QueryMax(1,1,n,id[x],id[y]));
return ans;
}
int main(){
read(n);
for(int i=1;i<n;i++){
int u,v;
read(u),read(v);
add(u,v),add(v,u);
}
for(int i=1;i<=n;i++) read(val[i]);
dfs1(1,0,1);
dfs2(1,1);
Build(1,1,n);
read(q);
while(q--){
char opt[10];
int u,v;
scanf("%s",opt);
scanf("%d%d",&u,&v);
if(opt[1]=='M'){
write(qMax(u,v)),printf("\n");
}else if(opt[1]=='S'){
write(qSum(u,v)),printf("\n");
}else{
UpDate(1,1,n,id[u], v);
}
}
return 0;
}
P2486 [SDOI2011]染色
就是线段树操作改了一下,左端点的颜色 \(lc\),右端点的颜色 \(rc\),区间成段更新的标记 \(lazy\),区间有多少颜色段 \(val\)。
区间合并时如果左子树右端点与右子树左端点颜色相同,那么总区间答案要 -1 。
在树剖的时候也是同理,要记录链的左右端点颜色,如果遇上一条相连的链左端点(深度小的编号小)颜色相同,那么也要 -1。
在最后在 \(x,y\)在同一条链上也是同理。
细节详见代码
点击查看代码
#include<bits/stdc++.h>
#define N 100005
#define ls u<<1
#define rs u<<1|1
using namespace std;
int n,m;
int tot,Head[N],to[N<<2],Next[N<<1];
int cnt,w[N],size[N],V[N],fa[N],top[N],d[N],son[N],id[N];
int val[N<<4],lazy[N<<4],lc[N<<4],rc[N<<4];
void add(int u,int v){
to[++tot]=v,Next[tot]=Head[u],Head[u]=tot;
}
void PushUp(int u){
val[u]=val[ls]+val[rs];
if(rc[ls]==lc[rs]) val[u]--;
lc[u]=lc[ls],rc[u]=rc[rs];
}
void PushDown(int u){
if(lazy[u]){
lc[ls]=rc[ls]=lc[rs]=rc[rs]=lazy[ls]=lazy[rs]=lazy[u];
val[ls]=val[rs]=1;
lazy[u]=0;
}
}
void Build(int u,int l,int r){
if(l>=r){
val[u]=1;
lc[u]=rc[u]=V[l];
return ;
}
int mid=(l+r)>>1;
Build(ls,l,mid);
Build(rs,mid+1,r);
PushUp(u);
}
void dfs1(int x,int f,int dep){
d[x]=dep,size[x]=1,fa[x]=f;
int maxson=-1;
for(int i=Head[x];i;i=Next[i]){
int y=to[i];
if(y==f) continue;
dfs1(y,x,dep+1);
size[x]+=size[y];
if(size[y]>maxson) maxson=size[y],son[x]=y;
}
}
void dfs2(int x,int topf){
id[x]=++cnt,top[x]=topf,V[cnt]=w[x];
if(!son[x]) return ;
dfs2(son[x],topf);
for(int i=Head[x];i;i=Next[i]){
int y=to[i];
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
void UpDate(int u,int l,int r,int L,int R,int k){
if(L<=l&&r<=R){
lc[u]=rc[u]=lazy[u]=k;
val[u]=1;
return ;
}
PushDown(u);
int mid=(l+r)>>1;
if(L<=mid) UpDate(ls,l,mid,L,R,k);
if(R>mid) UpDate(rs,mid+1,r,L,R,k);
PushUp(u);
}
void UpDate_C(int x,int y,int k){
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
UpDate(1,1,n,id[top[x]],id[x],k);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
UpDate(1,1,n,id[x],id[y],k);
}
int LC,RC;
int Query(int u,int l,int r,int L,int R){
if(L<=l&&r<=R){
if(l==L) LC=lc[u];
if(r==R) RC=rc[u];
return val[u];
}
PushDown(u);
int mid=(l+r)>>1;
if(R<=mid) return Query(ls,l,mid,L,R);
if(L>mid) return Query(rs,mid+1,r,L,R);
int ans=Query(ls,l,mid,L,R)+Query(rs,mid+1,r,L,R);
if(lc[rs]==rc[ls]) ans--;
return ans;
}
int QSum(int x,int y){
int ans=0,ans1=-1,ans2=-1;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y),swap(ans1,ans2);
//ans1表示 x 这条链之前链的最上面节点的颜色,对应线段树的 LC,ans2 表示 y 的
ans+=Query(1,1,n,id[top[x]],id[x]);
if(RC==ans1) ans--; // 相邻重链的节点颜色相同
ans1=LC;
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y),swap(ans1,ans2);
ans+=Query(1,1,n,id[x],id[y]);
if(LC==ans1) ans--;
if(RC==ans2) ans--;
printf("%d\n",ans);
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs1(1,0,1);
dfs2(1,1);
Build(1,1,n);
while(m--){
char opt[2];
int a,b,c;
scanf("%s%d%d",opt,&a,&b);
if(opt[0]=='C'){
scanf("%d",&c);
UpDate_C(a,b,c);
}else{
QSum(a,b);
}
}
return 0;
}
P3313 [SDOI2014]旅行
对于不同的宗教直接建不同的线段树来维护,由于空间会超,再加一个动态开点。
点击查看代码
#include<bits/stdc++.h>
#define N 200005
using namespace std;
void read(int &o) {
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
o=x*f;
}
void write(int x){
if(x<0){putchar('-');x*=-1;}
if(x>9) write(x/10);
putchar(x%10+'0');
}
int n,q,T;
int tot,Head[N],to[N<<1],Next[N<<1];
int d[N],fa[N],top[N],size[N],son[N];
int id[N],cnt,w[N],c[N],root[N];
int sum[N<<4],maxn[N<<4],ls[N<<4],rs[N<<4];
void PushUp(int u){
sum[u]=sum[ls[u]]+sum[rs[u]];
maxn[u]=max(maxn[ls[u]],maxn[rs[u]]);
}
void add(int u,int v){
to[++tot]=v,Next[tot]=Head[u],Head[u]=tot;
}
void dfs1(int x,int f,int dep){
d[x]=dep,size[x]=1,fa[x]=f;
int maxson=-1;
for(int i=Head[x];i;i=Next[i]){
int y=to[i];
if(y==f) continue;
dfs1(y,x,dep+1);
size[x]+=size[y];
if(size[y]>maxson) maxson=size[y],son[x]=y;
}
}
void dfs2(int x,int topf){
id[x]=++cnt,top[x]=topf;
if(!son[x]) return ;
dfs2(son[x],topf);
for(int i=Head[x];i;i=Next[i]){
int y=to[i];
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
void UpDate(int &u,int l,int r,int pos,int v){
if(!u) u=++T;
if(l==r){
sum[u]=maxn[u]=v;
return ;
}
int mid=(l+r)>>1;
if(pos<=mid) UpDate(ls[u],l,mid,pos,v);
else UpDate(rs[u],mid+1,r,pos,v);
PushUp(u);
}
int QueryMax(int u,int l,int r,int L,int R){
if(L<=l&&r<=R) return maxn[u];
int mid=(l+r)>>1,ans=-0x3f3f3f3f;
if(L<=mid) ans=max(ans,QueryMax(ls[u],l,mid,L,R));
if(R>mid) ans=max(ans,QueryMax(rs[u],mid+1,r,L,R));
return ans;
}
int QuerySum(int u,int l,int r,int L,int R){
if(L<=l&&r<=R) return sum[u];
int mid=(l+r)>>1,ans=0;
if(L<=mid) ans+=QuerySum(ls[u],l,mid,L,R);
if(R>mid) ans+=QuerySum(rs[u],mid+1,r,L,R);
return ans;
}
int qSum(int x,int y,int t){
int ans=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
ans+=QuerySum(root[t],1,n,id[top[x]],id[x]);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
ans+=QuerySum(root[t],1,n,id[x],id[y]);
return ans;
}
int qMax(int x,int y,int t){
int ans=-0x3f3f3f3f;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y );
ans=max(ans,QueryMax(root[t],1,n,id[top[x]],id[x]));
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
ans=max(ans,QueryMax(root[t],1,n,id[x],id[y]));
return ans;
}
int main(){
read(n),read(q);
for(int i=1;i<=n;i++) read(w[i]),read(c[i]);
for(int i=1;i<n;i++){
int u,v;
read(u),read(v);
add(u,v),add(v,u);
}
dfs1(1,0,1);
dfs2(1,1);
for(int i=1;i<=n;i++) UpDate(root[c[i]],1,n,id[i],w[i]);
while(q--){
char opt[10];
int u,v;
scanf("%s",opt);
scanf("%d%d",&u,&v);
if(opt[1]=='M'){
write(qMax(u,v,c[u])),printf("\n");
}else if(opt[1]=='S'){
write(qSum(u,v,c[u])),printf("\n");
}else if(opt[1]=='C'){
UpDate(root[c[u]],1,n,id[u],0);
c[u]=v;
UpDate(root[c[u]],1,n,id[u],w[u]);
}else{
UpDate(root[c[u]],1,n,id[u],v);
w[u]=v;
}
}
return 0;
}
P1505 [国家集训队]旅游
将边权压到深度较大的点,这样可以保证不会有超过 1 条边在同一个点上。
然后就是树剖的模板操作了。
点击查看代码
#include<bits/stdc++.h>
#define ls u << 1
#define rs u << 1 | 1
using namespace std;
const int N = 3e5 + 51;
inline int read(){
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
return x * f;
}
int n, m, tot, cnt;
int Head[N], to[N << 1], Next[N << 1], edge[N << 1], flag[N << 1];
int fa[N], sz[N], d[N], son[N], top[N], id[N], vv[N], val[N], vi[N], vid[N];
int sum[N << 2], maxn[N << 2], minn[N << 2], lazy[N << 2];
inline void add(int u, int v, int w, int i){
to[++tot] = v, Next[tot] = Head[u], Head[u] = tot, edge[tot] = w, flag[tot] = i;
}
inline void dfs1(int x, int f){
fa[x] = f, d[x] = d[f] + 1, sz[x] = 1;
for(int i = Head[x]; i; i = Next[i]){
int y = to[i]; if(y == f) continue;
dfs1(y, x), sz[x] += sz[y], vv[y] = edge[i], vi[y] = flag[i];
if(sz[y] > sz[son[x]]) son[x] = y;
}
}
inline void dfs2(int x, int topf){
id[x] = ++cnt, top[x] = topf, val[cnt] = vv[x], vid[vi[x]] = cnt;
if(!son[x]) return ;
dfs2(son[x], topf);
for(int i = Head[x]; i; i = Next[i]){
int y = to[i]; if(y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
}
inline void PushUp(int u){
maxn[u] = max(maxn[ls], maxn[rs]);
sum[u] = sum[ls] + sum[rs];
minn[u] = min(minn[ls], minn[rs]);
}
inline void PushDown(int u){
if(lazy[u]){
lazy[ls] ^= 1, lazy[rs] ^= 1;
sum[ls] = -sum[ls], sum[rs] = -sum[rs];
swap(maxn[ls], minn[ls]), maxn[ls] = -maxn[ls], minn[ls] = -minn[ls];
swap(maxn[rs], minn[rs]), maxn[rs] = -maxn[rs], minn[rs] = -minn[rs];
lazy[u] = 0;
}
}
inline void Build(int u, int l, int r){
if(l == r) return maxn[u] = minn[u] = sum[u] = val[l], void();
int mid = (l + r) >> 1;
Build(ls, l, mid), Build(rs, mid + 1, r);
PushUp(u);
}
inline void UpDate1(int u, int l, int r, int L, int R){
if(L <= l && r <= R){
sum[u] = -sum[u], swap(maxn[u], minn[u]);
maxn[u] = -maxn[u], minn[u] = -minn[u], lazy[u] ^= 1;
return ;
}
PushDown(u);
int mid = (l + r) >> 1;
if(L <= mid) UpDate1(ls, l, mid, L, R);
if(R > mid) UpDate1(rs, mid + 1, r, L, R);
PushUp(u);
}
inline int QueryMax(int u, int l, int r, int L, int R){
if(L <= l && r <= R) return maxn[u];
PushDown(u);
int mid = (l + r) >> 1, ans = -1e9;
if(L <= mid) ans = max(ans, QueryMax(ls, l, mid, L, R));
if(R > mid) ans = max(ans, QueryMax(rs, mid + 1, r, L, R));
return ans;
}
inline int QueryMin(int u, int l, int r, int L, int R){
if(L <= l && r <= R) return minn[u];
PushDown(u);
int mid = (l + r) >> 1, ans = 1e9;
if(L <= mid) ans = min(ans, QueryMin(ls, l, mid, L, R));
if(R > mid) ans = min(ans, QueryMin(rs, mid + 1, r, L, R));
return ans;
}
inline int QuerySum(int u, int l, int r, int L, int R){
if(L <= l && r <= R) return sum[u];
PushDown(u);
int mid = (l + r) >> 1, ans = 0;
if(L <= mid) ans += QuerySum(ls, l, mid, L, R);
if(R > mid) ans += QuerySum(rs, mid + 1, r, L, R);
return ans;
}
inline void UpRoad(int x, int y){
while(top[x] != top[y]){
if(d[top[x]] < d[top[y]]) swap(x, y);
UpDate1(1, 1, n, id[top[x]], id[x]);
x = fa[top[x]];
}
if(d[x] > d[y]) swap(x, y);
UpDate1(1, 1, n, id[x] + 1, id[y]);
}
inline int Qmax(int x, int y){
int ans = -1e9;
while(top[x] != top[y]){
if(d[top[x]] < d[top[y]]) swap(x, y);
ans = max(QueryMax(1, 1, n, id[top[x]], id[x]), ans);
x = fa[top[x]];
}
if(d[x] > d[y]) swap(x, y);
ans = max(QueryMax(1, 1, n, id[x] + 1, id[y]), ans);
return ans;
}
inline int Qmin(int x, int y){
int ans = 1e9;
while(top[x] != top[y]){
if(d[top[x]] < d[top[y]]) swap(x, y);
ans = min(QueryMin(1, 1, n, id[top[x]], id[x]), ans);
x = fa[top[x]];
}
if(d[x] > d[y]) swap(x, y);
ans = min(QueryMin(1, 1, n, id[x] + 1, id[y]), ans);
return ans;
}
inline int Qsum(int x, int y){
int ans = 0;
while(top[x] != top[y]){
if(d[top[x]] < d[top[y]]) swap(x, y);
ans += QuerySum(1, 1, n, id[top[x]], id[x]);
x = fa[top[x]];
}
if(d[x] > d[y]) swap(x, y);
ans += QuerySum(1, 1, n, id[x] + 1, id[y]);
return ans;
}
inline void UpDate2(int u, int l, int r, int x, int v){
if(l == r) return maxn[u] = minn[u] = sum[u] = v, void();
PushDown(u);
int mid = (l + r) >> 1;
if(x <= mid) UpDate2(ls, l, mid, x, v);
else UpDate2(rs, mid + 1, r, x, v);
PushUp(u);
}
int main(){
// freopen("1.in", "r", stdin);
// freopen("1.out", "w", stdout);
n = read();
for(int i = 1; i < n; ++i){
int u = read() + 1, v = read() + 1, w = read();
add(u, v, w, i), add(v, u, w, i);
}
dfs1(1, 0), dfs2(1, 1);
Build(1, 1, n);
m = read();
while(m--){
char opt[51]; scanf("%s", opt + 1);
int x = read() + 1, y = read() + 1;
if(opt[1] == 'C') UpDate2(1, 1, n, vid[x - 1], y - 1);
else if(opt[1] == 'N') UpRoad(x, y);
else if(opt[1] == 'S') printf("%d\n", Qsum(x, y));
else if(opt[2] == 'A') printf("%d\n", Qmax(x, y));
else if(opt[2] == 'I') printf("%d\n", Qmin(x, y));
}
return 0;
}
P4211 [LNOI2014]LCA
喵喵题。
我们可以将一个点的深度转为这个点到根的路径上有多少个点。
显然,对于每个询问,我们要求的 \(z\) 与 \(i\in [l,r]\) 的 lca 都在 \(z\) 到根的路径上。
那么关键就在于如何将这些 lca 找到并统计答案,发现 \(lca(x,y)\) 其实就是 \(x\) 到根和 \(y\) 到根的路径相交的第一个节点,那么我们显然可以将 \(x\) 到根的路径先染上颜色,再从 \(y\) 爬上根,碰到第一个有颜色的节点就是 \(lca(x,y)\) 了。
那么我们可以将每个节点 \(i\) 到根上的路径都 + 1 ,再从 \(z\) 走到根统计答案即可,即用树剖统计路径和。
发现时间复杂度不太对,考虑优化。
发现有些点是重复做的,显然可以优化掉。可以用差分来优化,将区间\([l,r]\) 按 \(r\) 从小到大排序,答案与 \(l-1\) 的减一减即可,具体见代码。
点击查看代码
#include<bits/stdc++.h>
#define ls u << 1
#define rs u << 1 | 1
using namespace std;
const int N = 5e4 + 5l, mod = 201314;
inline int read(){
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
return x * f;
}
int n, m, tot, cnt;
int Head[N], to[N], Next[N];
int d[N], sz[N], son[N], id[N], fa[N], top[N];
int sum[N << 2], ans1[N], ans2[N], lazy[N << 2];
struct query{
int id, pos, z;
bool flag;
bool operator < (const query &A) const{return pos < A.pos;}
}q[N << 1];
inline void add(int u, int v){
to[++tot] = v, Next[tot] = Head[u], Head[u] = tot;
}
inline void dfs1(int x, int f){
fa[x] = f, d[x] = d[f] + 1, sz[x] = 1;
for(int i = Head[x]; i; i = Next[i]){
int y = to[i]; if(y == f) continue;
dfs1(y, x); sz[x] += sz[y];
if(sz[y] > sz[son[x]]) son[x] = y;
}
}
inline void dfs2(int x, int topf){
id[x] = ++cnt, top[x] = topf;
if(!son[x]) return ;
dfs2(son[x], topf);
for(int i = Head[x]; i; i = Next[i]){
int y = to[i]; if(y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
}
inline void PushDown(int u, int l, int r){
if(lazy[u]){
int mid = (l + r) >> 1;
lazy[ls] += lazy[u], lazy[rs] += lazy[u], lazy[ls] %= mod, lazy[rs] %= mod;
sum[ls] = (sum[ls] + lazy[u] * (mid - l + 1)) % mod, sum[rs] = (sum[rs] + lazy[u] * (r - mid)) % mod;
lazy[u] = 0;
}
}
inline void PushUp(int u){
sum[u] = (sum[ls] + sum[rs]) % mod;
}
inline void UpDate(int u, int l, int r, int L, int R){
if(L <= l && r <= R) return sum[u] = (sum[u] + r - l + 1) % mod, lazy[u] = (lazy[u] + 1) % mod, void();
PushDown(u, l, r);
int mid = (l + r) >> 1;
if(L <= mid) UpDate(ls, l, mid, L, R);
if(R > mid) UpDate(rs, mid + 1, r, L, R);
PushUp(u);
}
inline void UpRoad(int x, int y){
while(top[x] != top[y]){
if(d[top[x]] < d[top[y]]) swap(x, y);
UpDate(1, 1, n, id[top[x]], id[x]);
x = fa[top[x]];
}
if(d[x] > d[y]) swap(x, y);
UpDate(1, 1, n, id[x], id[y]);
}
inline int Query(int u, int l, int r, int L, int R){
if(L <= l && r <= R) return sum[u];
PushDown(u, l, r);
int mid = (l + r) >> 1, ans = 0;
if(L <= mid) ans += Query(ls, l, mid, L, R);
if(R > mid) ans += Query(rs, mid + 1, r, L, R);
return ans % mod;
}
inline int Qsum(int x, int y){
int ans = 0;
while(top[x] != top[y]){
if(d[top[x]] < d[top[y]]) swap(x, y);
ans += Query(1, 1, n, id[top[x]], id[x]), ans %= mod;
x = fa[top[x]];
}
if(d[x] > d[y]) swap(x, y);
ans += Query(1, 1, n, id[x], id[y]);
return ans % mod;
}
int main(){
n = read(), m = read();
for(int i = 2; i <= n; ++i){
int u = read() + 1;
add(u, i);
}
dfs1(1, 0), dfs2(1, 1), tot = 0;
for(int i = 1; i <= m; ++i){
int l = read() + 1, r = read() + 1, z = read() + 1;
q[++tot] = (query){i, l - 1, z, 0};
q[++tot] = (query){i, r, z, 1};
}
sort(q + 1, q + 1 + tot);
int now = 0;
for(int i = 1; i <= tot; ++i){
while(now < q[i].pos) UpRoad(1, ++now);
int num = q[i].id;
if(q[i].flag) ans2[num] = Qsum(1, q[i].z);
else ans1[num] = Qsum(1, q[i].z);
}
for(int i = 1; i <= m; ++i) printf("%d\n", (ans2[i] - ans1[i] + mod) % mod);
return 0;
}