线段树树链剖分(点权)
模板题:树链剖分
参考博客:树链剖分详解(洛谷模板 P3384)
前置技能:线段树
#define lson (p<<1)
#define rson (p<<1|1)
建树的时候需要注意的:
void build(int s,int t,int p)
{
if(s==t)
{
d[p]=wt[s]%mod;
/*此处wt[]是排序好之后的序号对应的权重*/
return;
}
int m=(s+t)>>1;
build(s,m,lson),build(m+1,t,rson);
d[p]=(d[lson]+d[rson])%mod;
}
树链剖分:
初始化工作:
int dep[maxn],f[maxn],siz[maxn],son[maxn];
int top[maxn],w[maxn<<1],id[maxn],tot=0,n;
void dfs1(int x,int fa,int deep)
{
dep[x]=deep,f[x]=fa,siz[x]=1;
int maxson=-1;
for(int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa) continue;
dfs1(v,x,deep+1);
siz[x]+=siz[v];
if(siz[v]>maxson) maxson=siz[v],son[x]=v;
}
}
void dfs2(int u,int topf)
{
id[u]=++tot;
wt[tot]=w[u];
top[u]=topf;
if(!son[u]) return;
dfs2(son[u],topf);
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==f[u]||v==son[u]) continue;
dfs2(v,v);
}
}
区间修改(最短路径上所有节点的值加 k):
void updrange(int x,int y,int k)
{
k%=mod;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(id[top[x]],id[x],1,n,k,1);
x=f[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
update(id[x],id[y],1,n,k,1);
}
区间查询(最短路径上所有节点的值之和):
ll qrange(int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans=(ans+getsum(id[top[x]],id[x],1,n,1)%mod)%mod;
x=f[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans=(ans+getsum(id[x],id[y],1,n,1)%mod)%mod;
return ans;
}
子树内所有节点值都加上 k:
void updson(int x,int k)
{
update(id[x],id[x]+siz[x]-1,1,n,k,1);
}
子树内所有节点值之和:
ll qson(int x)
{
return getsum(id[x],id[x]+siz[x]-1,1,n,1)%mod;
}
模板题代码:
// Created by CAD on 2019/8/11.
#include <bits/stdc++.h>
#define lson (p<<1)
#define rson (p<<1|1)
using namespace std;
using pii=pair<int, int>;
using piii=pair<pair<int, int>, int>;
using ll=long long;
const int maxn=1e5+5;
/*线段树*/
int d[maxn<<2],wt[maxn],laz[maxn << 2];
ll mod;
void build(int s,int t,int p)
{
if(s==t)
{
d[p]=wt[s]%mod;
return;
}
int m=(s+t)>>1;
build(s,m,lson),build(m+1,t,rson);
d[p]=(d[lson]+d[rson])%mod;
}
void pushdown(int s,int t,int p)
{
int m=(s+t)>>1;
d[lson]=(d[lson]+(m-s+1)*laz[p])%mod,d[rson]=(d[rson]+(t-m)*laz[p])%mod;
laz[lson]=(laz[lson]+laz[p])%mod,laz[rson]=(laz[rson]+laz[p])%mod;
laz[p]=0;
}
void update(int l,int r,int s,int t,int c,int p)
{
if(l<=s&&t<=r)
{
d[p]+=c*(t-s+1);
laz[p]+=c;
return ;
}
if(laz[p]) pushdown(s,t,p);
int m=(s+t)>>1;
if(l<=m) update(l,r,s,m,c,lson);
if(r>m) update(l,r,m+1,t,c,rson);
d[p]=(d[rson]+d[lson])%mod;
}
ll getsum(int l,int r,int s,int t,int p)
{
if(l<=s&&t<=r) return d[p]%mod;
if(laz[p]) pushdown(s,t,p);
int m=(s+t)>>1;
ll sum=0;
if(l<=m) sum=(sum+getsum(l,r,s,m,lson))%mod;
if(r>m) sum=(sum+getsum(l,r,m+1,t,rson))%mod;
return sum%mod;
}
int cnt=0,head[maxn<<1];
struct edge{
int to,next;
}e[maxn<<1];
void add(int u,int v)
{
e[++cnt].to=v;
e[cnt].next=head[u];
head[u]=cnt;
}
int dep[maxn],f[maxn],siz[maxn],son[maxn];
int top[maxn],w[maxn<<1],id[maxn],tot=0,n;
void dfs1(int x,int fa,int deep)
{
dep[x]=deep,f[x]=fa,siz[x]=1;
int maxson=-1;
for(int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa) continue;
dfs1(v,x,deep+1);
siz[x]+=siz[v];
if(siz[v]>maxson) maxson=siz[v],son[x]=v;
}
}
void dfs2(int u,int topf)
{
id[u]=++tot;
wt[tot]=w[u];
top[u]=topf;
if(!son[u]) return;
dfs2(son[u],topf);
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==f[u]||v==son[u]) continue;
dfs2(v,v);
}
}
ll qrange(int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans=(ans+getsum(id[top[x]],id[x],1,n,1)%mod)%mod;
x=f[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans=(ans+getsum(id[x],id[y],1,n,1)%mod)%mod;
return ans;
}
void updrange(int x,int y,int k)
{
k%=mod;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(id[top[x]],id[x],1,n,k,1);
x=f[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
update(id[x],id[y],1,n,k,1);
}
ll qson(int x)
{
return getsum(id[x],id[x]+siz[x]-1,1,n,1)%mod;
}
void updson(int x,int k)
{
update(id[x],id[x]+siz[x]-1,1,n,k,1);
}
ll m,r;
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cin>>n>>m>>r>>mod;
for(int i=1;i<=n;++i) cin>>w[i];
for(int i=1,u,v;i<n;++i) cin>>u>>v,add(u,v),add(v,u);
dfs1(r,0,1);
dfs2(r,r);
build(1,n,1);
while(m--)
{
int k,x,y,z;
cin>>k;
if(k==1)
cin>>x>>y>>z,updrange(x,y,z);
else if(k==2)
cin>>x>>y,cout<<qrange(x,y)<<endl;
else if(k==3)
cin>>x>>y,updson(x,y);
else if(k==4)
cin>>x,cout<<qson(x)<<endl;
}
return 0;
}
CAD加油!欢迎跟我一起讨论学习算法,QQ:1401650042