洛谷 P3384 【模板】树链剖分 题解
因为是模板题,所以思维难度并不是很大。
主要思路就是用树剖进行整体架构,用线段树进行求区间和的操作。
其中对每个节点进行基于dfs
序的重编号,以方便用线段树维护区间和。
代码量略大,笔者写了219行。数据结构建议进行封装,以避免命名冲突和理解上的混乱。
#include<bits/stdc++.h>
//强制内联展开
#ifndef ONLINE_JUDGE
#define UMP45 __attribute__((always_inline))
#elif __cplusplus<201103L
#define UMP45 __attribute__((always_inline))
#else
#define UMP45 [[gnu::always_inline]]
#endif // __cplusplus<201103L
typedef long long lli;
constexpr int Size=1e6+1;
using namespace std;
int n,m,r,Mod,opt,x,y,z;
int a[Size],w[Size],head[Size];
namespace Segment_tree
{
lli tr[Size<<2],tag[Size<<2];
lli query(int now,int gl,int gr,int l,int r);
void build(int now,int l,int r);
void update(int now,int gl,int gr,int l,int r,int k);
void push_down(int now,int l,int r);
inline UMP45 void modify(int now,int l,int r,int k)
{
tag[now]+=k;
tr[now]+=k*(r-l+1);
}
inline UMP45 void push_up(int now)
{
tr[now]=(tr[now<<1]+tr[now<<1|1])%Mod;
}
inline UMP45 lli query(int l,int r)//此处重载query与update,方便调用
{
return query(1,l,r,1,n);
}
inline UMP45 void update(int l,int r,int k)
{
update(1,l,r,1,n,k);
}
void build(int now,int l,int r)
{
tag[now]=0;
if(l==r)
return (void)(tr[now]=w[l]%Mod);
int mid=(l+r)>>1;
build(now<<1,l,mid);
build(now<<1|1,mid+1,r);
push_up(now);
}
void update(int now,int gl,int gr,int l,int r,int k)
{
if(gl<=l&&gr>=r){
modify(now,l,r,k);
return;
}
push_down(now,l,r);
int mid=(l+r)>>1;
if(gl<=mid) update(now<<1,gl,gr,l,mid,k);
if(gr>mid) update(now<<1|1,gl,gr,mid+1,r,k);
push_up(now);
}
lli query(int now,int gl,int gr,int l,int r)
{
lli ret=0;
if(gl<=l&&gr>=r) return tr[now];
int mid=(l+r)>>1;
push_down(now,l,r);
if(gl<=mid) ret=(ret+query(now<<1,gl,gr,l,mid))%Mod;
if(gr>mid) ret=(ret+query(now<<1|1,gl,gr,mid+1,r))%Mod;
return ret;
}
void push_down(int now,int l,int r)
{
int mid=(l+r)>>1;
modify(now<<1,l,mid,tag[now]);
modify(now<<1|1,mid+1,r,tag[now]);
tr[now<<1]%=Mod;
tr[now<<1|1]%=Mod;
tag[now]=0;
}
}
struct edge{
int v,nxt;
edge():v(0),nxt(0){}
edge(int _v,int _n):v(_v),nxt(_n){}
}e[Size];
inline UMP45 void add_edge(int u,int v)
{
static int cnt;
e[++cnt]={v,head[u]};
head[u]=cnt;
}
namespace Subdivision
{
struct node{
int id,fa,dep,top,son,size;
node():id(0),fa(0),dep(0),top(0),son(0),size(1){}
node(int _f,int _d):id(0),fa(_f),dep(_d),top(0),son(0),size(1){}
}tr[Size];
void dfs1(int now,int fa,int dep);
void dfs2(int now,int top);
int query(int a,int b);
void update(int a,int b,int k);
inline UMP45 node& ctop(int id)
{
return tr[tr[id].top];
}
inline UMP45 int query(int x)
{
return Segment_tree::query(tr[x].id,tr[x].id+tr[x].size-1);
}
inline UMP45 void update(int x,int k)
{
Segment_tree::update(tr[x].id,tr[x].id+tr[x].size-1,k);
}
inline UMP45 void init()
{
dfs1(r,0,1);
dfs2(r,r);
Segment_tree::build(1,1,n);
}
void dfs1(int now,int fa,int dep)
{
tr[now]={fa,dep};
int maxcnt=-1;
for(int i=head[now];i!=0;i=e[i].nxt){
if(e[i].v==fa)
continue;
dfs1(e[i].v,now,dep+1);
tr[now].size+=tr[e[i].v].size;
if(tr[e[i].v].size>maxcnt){
tr[now].son=e[i].v;
maxcnt=tr[e[i].v].size;
}
}
}
void dfs2(int now,int top)
{
static int cnt;
tr[now].id=++cnt;
tr[now].top=top;
w[cnt]=a[now];
if(tr[now].son==0)
return;
dfs2(tr[now].son,top);
for(int i=head[now];i!=0;i=e[i].nxt)
if(e[i].v!=tr[now].fa&&e[i].v!=tr[now].son)
dfs2(e[i].v,e[i].v);
}
int query(int a,int b)
{
int res=0;
while(tr[a].top!=tr[b].top){
if(ctop(a).dep<ctop(b).dep)
swap(a,b);
res=(res+Segment_tree::query(ctop(a).id,tr[a].id))%Mod;
a=ctop(a).fa;
}
if(tr[a].dep>tr[b].dep)
swap(a,b);
return (res+Segment_tree::query(tr[a].id,tr[b].id))%Mod;
}
void update(int a,int b,int k)
{
k%=Mod;
while(tr[a].top!=tr[b].top){
if(ctop(a).dep<ctop(b).dep)
swap(a,b);
Segment_tree::update(ctop(a).id,tr[a].id,k);
a=ctop(a).fa;
}
if(tr[a].dep>tr[b].dep)
swap(a,b);
Segment_tree::update(tr[a].id,tr[b].id,k);
}
}
int main()
{
scanf("%d%d%d%d",&n,&m,&r,&Mod);
for(int i=1;i<=n;i++)
scanf("%d",a+i);
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
add_edge(u,v);
add_edge(v,u);
}
Subdivision::init();
while(m--){
scanf("%d%d",&opt,&x);
switch(opt){
case 1:
scanf("%d%d",&y,&z);
Subdivision::update(x,y,z);
break;
case 2:
scanf("%d",&y);
printf("%d\n",Subdivision::query(x,y));
break;
case 3:
scanf("%d",&z);
Subdivision::update(x,z);
break;
case 4:
printf("%d\n",Subdivision::query(x));
}
}
return 0;
}