树链剖分模板题
树链剖分
树链剖分就是把一个树有顺序地分成几个链,记录每个点的顺序,存在数组中,就可以用线段树维护树上的一些操作
以下是几个模板题:
《信息学奥赛一本通提高篇》上的模板是这样的:
1 #include<iostream> 2 #include<algorithm> 3 #include<cstdio> 4 #include<cstring> 5 using namespace std; 6 const int N=31000; 7 const int M=124000; 8 int n,m,Summ,Maxx; 9 int seg[N],rev[M],size[N],son[N],top[N],dep[N]; 10 int sum[M],num[M],father[M],Max[M]; 11 int first[M],next[M],go[M]; 12 13 void query(int k,int l,int r,int L,int R) //区间询问 14 { 15 if(L<=l&&r<=R) 16 { 17 Summ+=sum[k]; 18 Maxx=max(Maxx,Max[k]); 19 return; 20 } 21 int mid=(l+r)>>1; 22 if(mid>=L) query(k<<1,l,mid,L,R); 23 if(mid+1<=R) query(k<<1|1,mid+1,r,L,R); 24 } 25 26 void change(int k,int l,int r,int Val,int pos) //单点修改 27 { 28 if(l==r&&r==pos) 29 { 30 sum[k]=Val; 31 Max[k]=Val; 32 return; 33 } 34 int mid=(l+r)>>1; 35 if(mid>=pos) change(k<<1,l,mid,Val,pos); 36 if(mid+1<=pos) change(k<<1|1,mid+1,r,Val,pos); 37 sum[k]=sum[k<<1]+sum[k<<1|1]; 38 Max[k]=max(Max[k<<1],Max[k<<1|1]); 39 } 40 41 void dfs1(int u,int f) 42 { 43 int e,v; 44 size[u]=1; 45 father[u]=f; 46 dep[u]=dep[f]+1; 47 for(e=first[u];v=go[e],e;e=next[e]) 48 if(v!=f) 49 { 50 dfs1(v,u); 51 size[u]+=size[v]; 52 if(size[v]>size[son[u]]) 53 son[u]=v; 54 } 55 } 56 57 void dfs2(int u,int f) 58 { 59 int e,v; 60 if(son[u]) 61 { 62 seg[son[u]]=++seg[0]; 63 top[son[u]]=top[u]; 64 rev[seg[0]]=son[u]; 65 dfs2(son[u],u); 66 } 67 for(e=first[u];v=go[e],e;e=next[e]) 68 if(!top[v]) 69 { 70 seg[v]=++seg[0]; 71 rev[seg[0]]=v; 72 top[v]=v; 73 dfs2(v,u); 74 } 75 } 76 77 void build(int k,int l,int r) 78 { 79 int mid=(l+r)>>1; 80 if(l==r) 81 { 82 Max[k]=sum[k]=num[rev[l]]; 83 return; 84 } 85 build(k<<1,l,mid); 86 build(k<<1|1,mid+1,r); 87 sum[k]=sum[k<<1]+sum[k<<1|1]; 88 Max[k]=max(Max[k<<1],Max[k<<1|1]); 89 } 90 91 inline int get() 92 { 93 char c; 94 int sign=1; 95 while((c=getchar())<'0'||c>'9') 96 if(c=='-') sign=-1; 97 int res=c-'0'; 98 while((c=getchar())>='0'&&c<='9') 99 res=res*10+c-'0'; 100 return res*sign; 101 } 102 103 int tot; 104 105 inline void add(int x,int y) 106 { 107 next[++tot]=first[x]; 108 first[x]=tot; 109 go[tot]=y; 110 } 111 112 inline void insert(int x,int y) 113 { 114 add(x,y); add(y,x); 115 } 116 117 inline void ask(int x,int y) 118 { 119 int fx=top[x],fy=top[y]; 120 while(fx!=fy) 121 { 122 if(dep[fx]<dep[fy]) swap(x,y),swap(fx,fy); 123 query(1,1,seg[0],seg[fx],seg[x]); 124 x=father[fx];fx=top[x]; 125 } 126 if(dep[x]<dep[y]) swap(x,y); 127 query(1,1,seg[0],seg[y],seg[x]); 128 } 129 130 int main() 131 { 132 int i; 133 n=get(); 134 for(i=1;i<n;i++) 135 insert(get(),get()); 136 for(i=1;i<=n;i++) 137 num[i]=get(); 138 dfs1(1,0); 139 seg[0]=seg[1]=top[1]=rev[1]=1; 140 dfs2(1,0); 141 build(1,1,seg[0]); 142 m=get(); 143 char sr[10]; 144 int u,v; 145 for(i=1;i<=m;i++) 146 { 147 scanf("%s",sr+1); 148 u=get(); 149 v=get(); 150 if(sr[1]=='C') 151 change(1,1,seg[0],v,seg[u]); 152 else 153 { 154 Summ=0; 155 Maxx=-10000000; 156 ask(u,v); 157 if(sr[2]=='M') 158 printf("%d\n",Maxx); 159 else 160 printf("%d\n",Summ); 161 } 162 } 163 return 0; 164 }
1 #include<iostream> 2 #include<cstring> 3 #include<cstdio> 4 5 using namespace std; 6 7 #define N 100010 8 #define M 400040 9 #define lc(p) ((p)<<1) //左儿子 10 #define rc(p) ((p)<<1|1) //右儿子 11 #define mid ((l+r)>>1) 12 13 int fa[N],son[N],size[N],top[N],dep[N],seg[N]; //父亲节点;重儿子;子树大小;链首节点;深度;在线段树中的编号 14 int n,m,R,P,tot=1,next[M],head[M],to[M]; 15 int Sum[M],dealta[M],rev[M],num[N]; //rev[i]:线段树节点为i的点输入的顺序(在num数组的下标) 16 17 inline int read(){ 18 int x=0,f=1; char c=getchar(); 19 while(c<'0'||c>'9') { if(c=='-') f=-1; c=getchar(); } 20 while('0'<=c&&c<='9') { x=(x<<3)+(x<<1)+c-'0'; c=getchar(); } 21 return x*f; 22 } 23 24 inline void add(int x,int y){ 25 to[++tot]=y; 26 next[tot]=head[x]; 27 head[x]=tot; 28 } 29 30 //下为线段树模板 31 inline void push_up(int p){ Sum[p]=(Sum[lc(p)]+Sum[rc(p)])%P; } 32 33 inline void push_down(int p,int l,int r) 34 { 35 int &d=dealta[p]; 36 dealta[lc(p)]=(dealta[lc(p)]+d)%P; 37 dealta[rc(p)]=(dealta[rc(p)]+d)%P; 38 Sum[lc(p)]=(Sum[lc(p)]+((mid-l+1)*d))%P; 39 Sum[rc(p)]=(Sum[rc(p)]+((r-mid)*d))%P; 40 d=0; 41 } 42 43 void build(int p=1,int l=1,int r=seg[0]) 44 { 45 if(l==r) { 46 Sum[p]=num[rev[l]]%P; 47 return; 48 } 49 build(lc(p),l,mid); 50 build(rc(p),mid+1,r); 51 push_up(p); 52 } 53 54 int query(int L,int R,int p=1,int l=1,int r=seg[0]) 55 { 56 if(l>R||r<L) return 0; 57 if(L<=l&&r<=R) return Sum[p]%P; 58 push_down(p,l,r); 59 int ans=0; 60 if(L<=mid) ans+=query(L,R,lc(p),l,mid); 61 if(R>mid) ans+=query(L,R,rc(p),mid+1,r); 62 push_up(p); 63 return ans %P; 64 } 65 66 void update(int L,int R,int Val,int p=1,int l=1,int r=seg[0]) 67 { 68 if(l>R||r<L) return; 69 if(L<=l&&r<=R){ 70 dealta[p]=(dealta[p]+Val)%P; 71 Sum[p]=(Sum[p]+(r-l+1)*Val)%P; 72 return; 73 } 74 push_down(p,l,r); 75 if(L<=mid) update(L,R,Val,lc(p),l,mid); 76 if(R>mid) update(L,R,Val,rc(p),mid+1,r); 77 push_up(p); 78 } 79 80 //第一遍dfs 得到fa,dep,size,son 81 void dfs1(int u,int f){ 82 fa[u]=f; 83 size[u]=1; 84 dep[u]=dep[f]+1; 85 for(int i=head[u];i;i=next[i]) 86 if(to[i]!=f){ 87 int v=to[i]; 88 dfs1(v,u); 89 size[u]+=size[v]; 90 if(size[v]>size[son[u]]) 91 son[u]=v; 92 } 93 } 94 //第二遍dfs 得到seg,top,rev 95 void dfs2(int u,int f){ 96 if(son[u]){ //先搜重儿子,保证每一条链在线段树是连续的区间 97 seg[son[u]]=++seg[0]; 98 top[son[u]]=top[u]; 99 rev[seg[0]]=son[u]; 100 dfs2(son[u],u); 101 } 102 for(int i=head[u];i;i=next[i]) 103 if(!top[to[i]]){ //如果没有更新过 104 int v=to[i]; 105 seg[v]=++seg[0]; 106 top[v]=v; 107 rev[seg[0]]=v; 108 dfs2(v,u); 109 } 110 } 111 112 int ask1(int x,int y){ 113 int fx=top[x],fy=top[y],ans=0; 114 while(fx!=fy){ //若不在同一条链上,就把深度大的向上条,最多跳logn次 115 if(dep[fx]<dep[fy]) { swap(x,y); swap(fx,fy); } 116 ans=(ans+query(seg[fx],seg[x]))%P; 117 x=fa[fx]; fx=top[x]; 118 } 119 if(dep[x]>dep[y]) swap(x,y); 120 ans+=query(seg[x],seg[y]); //已经在同一条链上,直接区间查询 121 return ans%P; 122 } 123 124 void change1(int x,int y,int Val){ 125 int fx=top[x],fy=top[y]; 126 while(fx!=fy){ 127 if(dep[fx]<dep[fy]) { swap(x,y); swap(fx,fy); } 128 update(seg[fx],seg[x],Val); 129 x=fa[fx]; fx=top[x]; 130 } 131 if(dep[x]>dep[y]) swap(x,y); 132 update(seg[x],seg[y],Val); 133 } 134 135 int ask2(int p){ return query(seg[p],seg[p]+size[p]-1); } //一个节点的子树在线段树中是连续的一段区间 136 137 void change2(int p,int Val){ 138 update(seg[p],seg[p]+size[p]-1,Val); 139 } 140 141 int main() 142 { 143 scanf("%d%d%d%d",&n,&m,&R,&P); 144 for(int i=1;i<=n;i++) 145 num[i]=read(); 146 int x,y; 147 for(int i=1;i<n;i++){ 148 x=read(); y=read(); 149 add(x,y); add(y,x); 150 } 151 dfs1(R,0); 152 seg[0]=seg[R]=1; top[R]=rev[1]=R; 153 dfs2(R,0); 154 build(); 155 int t,z; 156 while(m--){ 157 t=read(); 158 switch(t){ 159 case 1:{ 160 x=read(); y=read(); z=read(); 161 change1(x,y,z); 162 break; 163 } 164 case 2:{ 165 x=read(); y=read(); 166 printf("%d\n",ask1(x,y)); 167 break; 168 } 169 case 3:{ 170 x=read(); z=read(); 171 change2(x,z); 172 break; 173 } 174 case 4:{ 175 x=read(); 176 printf("%d\n",ask2(x)); 177 break; 178 } 179 } 180 } 181 return 0; 182 }
1 #include<iostream> 2 #include<cstring> 3 #include<cstdio> 4 using namespace std; 5 #define int long long 6 #define N 100010 7 #define M 400010 8 #define lc(p) (p<<1) 9 #define rc(p) (p<<1|1) 10 #define mid ((l+r)>>1) 11 int num[N],son[N],size[N],fa[N],dep[N]; 12 int seg[N],rev[M],top[M],n,m; 13 int Sum[M],dealta[M]; 14 int head[M],to[M],next[M],tot=1; 15 16 inline int read(){ 17 int x=0,f=1; char c=getchar(); 18 while(c<'0'||c>'9') { if(c=='-') f=-1; c=getchar(); } 19 while('0'<=c&&c<='9') { x=(x<<3)+(x<<1)+c-'0'; c=getchar(); } 20 return x*f; 21 } 22 23 inline void add(int x,int y){ 24 to[++tot]=y; 25 next[tot]=head[x]; 26 head[x]=tot; 27 } 28 29 inline void push_up(int p){ 30 Sum[p]=Sum[lc(p)]+Sum[rc(p)]; 31 } 32 33 inline void f(int p,int l,int r,int d){ 34 dealta[p]+=d; Sum[p]+=(r-l+1)*d; 35 } 36 37 inline void push_down(int p,int l,int r){ 38 int &d=dealta[p]; 39 f(lc(p),l,mid,d); 40 f(rc(p),mid+1,r,d); 41 d=0; 42 } 43 44 void build(int p=1,int l=1,int r=seg[0]){ 45 if(l==r){ 46 Sum[p]=num[rev[l]]; 47 return; 48 } 49 build(lc(p),l,mid); 50 build(rc(p),mid+1,r); 51 push_up(p); 52 } 53 54 int query(int L,int R,int p=1,int l=1,int r=seg[0]) 55 { 56 if(L<=l&&r<=R) 57 return Sum[p]; 58 push_down(p,l,r); 59 int ans=0; 60 if(L<=mid) ans+=query(L,R,lc(p),l,mid); 61 if(R>mid) ans+=query(L,R,rc(p),mid+1,r); 62 push_up(p); 63 return ans; 64 } 65 66 void update(int L,int R,int Val,int p=1,int l=1,int r=seg[0]) 67 { 68 if(L<=l&&r<=R){ 69 f(p,l,r,Val); 70 return; 71 } 72 push_down(p,l,r); 73 if(L<=mid) update(L,R,Val,lc(p),l,mid); 74 if(R>mid) update(L,R,Val,rc(p),mid+1,r); 75 push_up(p); 76 } 77 78 void dfs1(int u,int f){ 79 fa[u]=f; 80 dep[u]=dep[f]+1; 81 size[u]=1; 82 for(int i=head[u];i;i=next[i]) 83 if(to[i]!=f){ 84 int v=to[i]; 85 dfs1(v,u); 86 size[u]+=size[v]; 87 if(size[v]>size[son[u]]) 88 son[u]=v; 89 } 90 } 91 92 void dfs2(int u,int f){ 93 if(son[u]){ 94 int v=son[u]; 95 seg[v]=++seg[0]; 96 top[v]=top[u]; 97 rev[seg[0]]=v; 98 dfs2(v,u); 99 } 100 for(int i=head[u];i;i=next[i]) 101 if(!top[to[i]]){ 102 int v=to[i]; 103 seg[v]=++seg[0]; 104 top[v]=v; 105 rev[seg[0]]=v; 106 dfs2(v,u); 107 } 108 } 109 110 int ask(int x){ 111 int fx=top[x],ans=0; 112 while(fx!=1){ 113 ans+=query(seg[fx],seg[x]); 114 x=fa[fx]; fx=top[x]; 115 } 116 ans+=query(1,seg[x]); 117 return ans; 118 } 119 120 void change(int x,int Val){ 121 update(seg[x],seg[x]+size[x]-1,Val); 122 } 123 #undef int 124 int main() 125 #define int long long 126 { 127 scanf("%lld%lld",&n,&m); 128 for(int i=1;i<=n;i++) 129 num[i]=read(); 130 int x,y,z; 131 for(int i=1;i<n;i++){ 132 x=read(); y=read(); 133 add(x,y); add(y,x); 134 } 135 dfs1(1,0); 136 seg[0]=seg[1]=rev[1]=top[1]=1; 137 dfs2(1,0); 138 build(); 139 while(m--){ 140 x=read(); y=read(); 141 if(x==1){ 142 z=read(); 143 update(seg[y],seg[y],z); 144 } 145 else if(x==2){ 146 z=read(); 147 change(y,z); 148 } 149 else printf("%lld\n",ask(y)); 150 } 151 return 0; 152 }