归纳(四):树链剖分
剖分的意义
能用线段树搞想要的信息。
(其实可以只用来求LCA)
需要的东西
七个数组,在两次dfs中处理出来。
dfs1:
dep[]:深度
fa[]:父亲
son[]:重儿子
siz[]:子树大小(包括自己)
dfs2:
indexx:新的编号
id[]:新的编号
top[]:链顶
int son[maxn],fa[maxn],dep[maxn],siz[maxn];
int id[maxn],indexx,top[maxn];
void dfs1(int nd,int p,int deep) {
dep[nd]=deep;
fa[nd]=p;
siz[nd]=1;
int maxson=-1;
for(int i=h[nd];~i;i=e[i].nxt) {
if(e[i].v==p) continue;
dfs1(e[i].v,nd,deep+1);
siz[nd]+=siz[e[i].v];
if(siz[e[i].v]>maxson)
son[nd]=e[i].v,maxson=siz[e[i].v];
}
}
void dfs2(int nd,int chain) {
id[nd]=++indexx;
top[nd]=chain;
if(!son[nd]) return ;
dfs2(son[nd],chain);
for(int i=h[nd];~i;i=e[i].nxt) {
if(e[i].v==fa[nd] || e[i].v==son[nd]) continue;
dfs2(e[i].v,e[i].v);
}
}
其实很好理解。
但不加线段树的树链剖分是没有灵魂的。
树链剖分求LCA
最最最基础的操作了。
原理是不在同一条重链上,就直接跳链。
int lca(int lnd,int rnd) {
while(top[lnd]!=top[rnd]) {
if(dep[top[lnd]]<dep[top[rnd]]) std::swap(lnd,rnd);
lnd=fa[top[lnd]];
}
return dep[lnd]<dep[rnd]?lnd:rnd;
}
【模板】树链剖分
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
用到了线段树。
重链由于重新编号,可以保证编号连续。
同时因为dfs序的原因,一个节点和它的子树的编号也是连续的。
分别可以解决1、2和3、4操作。
子树的操作十分简单,直接modify或者query,左端点是id[nd],右端点是id[nd]+siz[nd]-1。
路径与求LCA同理。
void Pmodify(int lnd,int rnd,int c) {
while(top[lnd]!=top[rnd]) {
if(dep[top[lnd]]<dep[top[rnd]]) std::swap(lnd,rnd);
modify(1,1,n,id[top[lnd]],id[lnd],c);
lnd=fa[top[lnd]];
}
if(dep[lnd]>dep[rnd]) std::swap(lnd,rnd);
modify(1,1,n,id[lnd],id[rnd],c);
}
int Pquery(int lnd,int rnd) {
int ans=0;
while(top[lnd]!=top[rnd]) {
if(dep[top[lnd]]<dep[top[rnd]]) std::swap(lnd,rnd);
vall=0;
query(1,1,n,id[top[lnd]],id[lnd]);
ans=(ans+vall)%mod;
lnd=fa[top[lnd]];
}
if(dep[lnd]>dep[rnd]) std::swap(lnd,rnd);
vall=0;
query(1,1,n,id[lnd],id[rnd]);
return (ans+vall)%mod;
}
void Smodify(int nd,int c) {
modify(1,1,n,id[nd],id[nd]+siz[nd]-1,c);
}
int Squery(int nd) {
vall=0;
query(1,1,n,id[nd],id[nd]+siz[nd]-1);
return vall;
}
不要在意query是void型的事情,已经摒弃这种写法了。
例题(一)
[NOI2015]软件包管理器
安装是把根到节点的路径改为1,卸载是子树改为0。
注意lazy_tag是有可能是0的,要把laz初值赋成-1。
其实这个是基本常识。
比如:
Challenge 5
我二月份就已经错过了。
其他题基本都是大同小异。
例题(二)
加油吧!
我才不会呢!
代码
#include<bits/stdc++.h>
const int maxn=3e4+5;
const int oo=0x3f3f3f3f;
inline int read() {
int x,f=1;char ch;while(!isdigit(ch=getchar())) (ch=='-') && (f=-1);
for(x=ch^'0';isdigit(ch=getchar());x=(x<<3)+(x<<1)+(ch^'0'));
return f*x;
}
int n;
struct Edge {
int nxt,v;
}e[maxn<<1];
int h[maxn],tot,w[maxn],wt[maxn];
void add_edge(int u,int v) {
e[++tot].v=v;
e[tot].nxt=h[u];
h[u]=tot;
}
int vmax[maxn<<2],sum[maxn<<2],vall;
inline void pushup(int rt) {
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
vmax[rt]=std::max(vmax[rt<<1],vmax[rt<<1|1]);
}
void build(int rt,int l,int r) {
if(l==r) {
sum[rt]=vmax[rt]=wt[l];
return ;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void modify(int rt,int l,int r,int pos,int c) {
if(l==r) {
sum[rt]=vmax[rt]=c;
return ;
}
int mid=(l+r)>>1;
if(pos<=mid) modify(rt<<1,l,mid,pos,c);
else modify(rt<<1|1,mid+1,r,pos,c);
pushup(rt);
}
int Squery(int rt,int l,int r,int L,int R) {
int mid=(l+r)>>1,ans=0;
if(L<=l && r<=R) return sum[rt];
if(L<=mid) ans+=Squery(rt<<1,l,mid,L,R);
if(R>mid) ans+=Squery(rt<<1|1,mid+1,r,L,R);
pushup(rt);
return ans;
}
int Mquery(int rt,int l,int r,int L,int R) {
int mid=(l+r)>>1,ans=-oo;
if(L<=l && r<=R) return vmax[rt];
if(L<=mid) ans=std::max(ans,Mquery(rt<<1,l,mid,L,R));
if(R>mid) ans=std::max(ans,Mquery(rt<<1|1,mid+1,r,L,R));
pushup(rt);
return ans;
}
int son[maxn],fa[maxn],siz[maxn],dep[maxn];
int indexx,top[maxn],id[maxn];
void dfs1(int nd,int p,int deep) {
dep[nd]=deep;
fa[nd]=p;
siz[nd]=1;
int maxson=-1;
for(int i=h[nd];~i;i=e[i].nxt) {
if(e[i].v==p) continue;
dfs1(e[i].v,nd,deep+1);
siz[nd]+=siz[e[i].v];
if(siz[e[i].v]>maxson) {
son[nd]=e[i].v;
maxson=siz[e[i].v];
}
}
}
void dfs2(int nd,int chain) {
id[nd]=++indexx;
wt[indexx]=w[nd];
top[nd]=chain;
if(!son[nd]) return ;
dfs2(son[nd],chain);
for(int i=h[nd];~i;i=e[i].nxt) {
if(e[i].v==fa[nd] || e[i].v==son[nd]) continue;
dfs2(e[i].v,e[i].v);
}
}
int qmax(int lnd,int rnd) {
int ans=-oo;
while(top[lnd]!=top[rnd]) {
if(dep[top[lnd]]<dep[top[rnd]]) std::swap(lnd,rnd);
ans=std::max(ans,Mquery(1,1,n,id[top[lnd]],id[lnd]));
lnd=fa[top[lnd]];
}
if(dep[lnd]>dep[rnd]) std::swap(lnd,rnd);
return std::max(ans,Mquery(1,1,n,id[lnd],id[rnd]));
}
int qsum(int lnd,int rnd) {
int ans=0;
while(top[lnd]!=top[rnd]) {
if(dep[top[lnd]]<dep[top[rnd]]) std::swap(lnd,rnd);
ans+=Squery(1,1,n,id[top[lnd]],id[lnd]);
lnd=fa[top[lnd]];
}
if(dep[lnd]>dep[rnd]) std::swap(lnd,rnd);
ans+=Squery(1,1,n,id[lnd],id[rnd]);
return ans;
}
int main() {
memset(h,-1,sizeof(h));
n=read();
for(int i=1;i<n;i++) {
int a=read(),b=read();
add_edge(a,b);
add_edge(b,a);
}
for(int i=1;i<=n;i++) w[i]=read();
dfs1(1,0,1);
dfs2(1,1);
build(1,1,n);
int q=read();
while(q--) {
char s[10];
scanf("%s",s);
int x=read(),y=read();
switch(s[1]) {
case 'H':modify(1,1,n,id[x],y);break;
case 'M':printf("%d\n",qmax(x,y));break;
case 'S':printf("%d\n",qsum(x,y));break;
default:break;
}
}
return 0;
}