树剖模板
把一棵树拆成若干个不相交的链,然后用一些数据结构去维护这些链
重链剖分
对于每一个节点,找出它的重儿子,那么这棵树就自然而然的被拆成了许多重链与许多轻链
如何对这些链进行维护?
首先,要对这些链进行维护,就要确保每个链上的节点都是连续的
因此我们需要对整棵树进行重新编号,然后利用 dfs 序的思想,用线段树或树状数组等进行维护
注意在进行重新编号的时候先访问重链,这样可以保证重链内的节点编号连续
Code Luogu P3384:
#include<bits/stdc++.h>
#define int long long
#define ll long long
#define fd(i,a,b) for(int i=a,_i=b;i<=_i;i=-~i)
#define bd(i,a,b) for(int i=a,_i=b;i>=_i;i=~-i)
using namespace std;
inline int read(){int x;scanf("%lld",&x);return x;}
inline void write(int x,int F=1)
{
if(F==0) printf("%lld ",x);
else if(F==1) printf("%lld\n",x);
else printf("%lld",x);
}
const int N=1e6+509;
int w[N],n,m,mod,Root;
int wnew[N];
//w[]、wt[]初始点权数组
vector<int> e[N];
struct ST
{
struct node
{
ll l,r,add,sum,mul;
#define l(x) st[x].l
#define r(x) st[x].r
#define add(x) st[x].add
#define sum(x) st[x].sum
#define mul(x) st[x].mul
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
}st[N<<1];
inline void pushdown(int p)
{
if(add(p)==0&&mul(p)==1) return;
sum(ls(p))=(sum(ls(p))*mul(p)%mod+add(p)*(r(ls(p))-l(ls(p))+1)%mod)%mod;
sum(rs(p))=(sum(rs(p))*mul(p)%mod+add(p)*(r(rs(p))-l(rs(p))+1)%mod)%mod;
mul(ls(p))=mul(ls(p))*mul(p)%mod;
mul(rs(p))=mul(rs(p))*mul(p)%mod;
add(ls(p))=(add(ls(p))*mul(p)%mod+add(p))%mod;
add(rs(p))=(add(rs(p))*mul(p)%mod+add(p))%mod;
add(p)=0;
mul(p)=1;
}
inline void pushup(int p)
{
sum(p)=sum(ls(p))+sum(rs(p));
sum(p)%=mod;
}
void Build(int p,int l,int r)
{
l(p)=l,r(p)=r;
add(p)=0,mul(p)=1;
if(l==r)
{
sum(p)=wnew[r]%mod;
return;
}
int mid=((r-l)>>1)+l;
Build(ls(p),l,mid);
Build(rs(p),mid+1,r);
pushup(p);
}
void Add(int p,int l,int r,ll v)
{
if(l(p)>=l&&r(p)<=r)
{
add(p)=(add(p)+v)%mod;
sum(p)=(sum(p)+v*(r(p)-l(p)+1)%mod)%mod;
return;
}
pushdown(p);
int mid=((r(p)-l(p))>>1)+l(p);
if(l<=mid) Add(ls(p),l,r,v);
if(mid<r) Add(rs(p),l,r,v);
pushup(p);
}
void Mul(int p,int l,int r,ll v)
{
if(l(p)>=l&&r(p)<=r)
{
add(p)=(add(p)*v)%mod;
mul(p)=(mul(p)*v)%mod;
sum(p)=(sum(p)*v)%mod;
return;
}
pushdown(p);
int mid=((r(p)-l(p))>>1)+l(p);
if(l<=mid) Mul(ls(p),l,r,v);
if(mid<r) Mul(rs(p),l,r,v);
pushup(p);
}
ll Ask(int p,int l,int r)
{
if(l(p)>=l&&r(p)<=r) return sum(p);
pushdown(p);
int mid=((r(p)-l(p))>>1)+l(p);
ll res=0;
if(l<=mid) res+=Ask(ls(p),l,r)%mod;
if(mid<r) res+=Ask(rs(p),l,r)%mod;
return res%mod;
}
}St;//线段树板子
struct TreeCut
{
int top[N],d[N],fa[N],son[N],siz[N],id[N],cnt=0;
//son[]重儿子编号,id[]新编号,fa[]父亲节点,cnt dfs_clock/dfs序,dep[]深度,siz[]子树大小,top[]当前链顶端节点
inline int qRange(int x,int y)
{
int res=0;
while(top[x]!=top[y])//当两个点不在同一条链上
{
if(d[top[x]]<d[top[y]]) swap(x,y);//把x点改为所在链顶端的深度更深的那个点
res+=St.Ask(1,id[top[x]],id[x]);//res 这一段区间的点权和
res%=mod,x=fa[top[x]];//把x跳到x所在链顶端的那个点的上面一个点
}
//直到两个点处于一条链上
if(d[x]>d[y]) swap(x,y);//把x点深度更深的那个点
res+=St.Ask(1,id[x],id[y]);//这时再加上此时两个点的区间和即可
return res%mod;
}
inline void upRange(int x,int y,int k)//同上
{
k%=mod;
while(top[x]!=top[y])
{
if(d[top[x]]<d[top[y]]) swap(x,y);
St.Add(1,id[top[x]],id[x],k);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
St.Add(1,id[x],id[y],k);
}
inline int qSon(int x) {return St.Ask(1,id[x],id[x]+siz[x]-1);}
//子树区间右端点为id[x]+siz[x]-1
inline void upSon(int x,int k) {St.Add(1,id[x],id[x]+siz[x]-1,k%mod);}//同上
inline void dfs1(int x,int fat,int deep)//x当前节点,fat父亲,deep深度
{
d[x]=deep,fa[x]=fat,siz[x]=1;//标记每个点的深度,父亲,子树大小
int Hson=-1;//记录重儿子的儿子数
for(int y:e[x])
{
if(y==fat) continue;
dfs1(y,x,deep+1);
siz[x]+=siz[y];//把它的儿子数加到它身上
if(siz[y]>Hson) son[x]=y,Hson=siz[y];//求重儿子
}
}
inline void dfs2(int x,int Top)//x当前节点,Top当前链的最顶端的节点
{
id[x]=++cnt,wnew[cnt]=w[x],top[x]=Top;
//标记每个点的新编号,把每个点的初始值赋到新编号上来,这个点所在链的顶端
if(!son[x]) return;
dfs2(son[x],Top);;//按先处理重儿子,再处理轻儿子的顺序递归处理
for(int y:e[x])
{
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);//对于每一个轻儿子都有一条从它自己开始的链
}
}
}Tc;
signed main()
{
#define FJ
#ifndef FJ
freopen("bzoj1066_.in","r",stdin);
freopen("bzoj1066_.out","w",stdout);
#endif
n=read(),m=read(),Root=read(),mod=read();
fd(i,1,n) w[i]=read();
fd(i,1,n-1)
{
int X=read(),Y=read();
e[X].push_back(Y);
e[Y].push_back(X);
}
Tc.dfs1(Root,0,1);
Tc.dfs2(Root,Root);
St.Build(1,1,n);
DO:m=~-m;
int k=read(),x,y,z;
if(k==1)
{
x=read(),y=read(),z=read();
Tc.upRange(x,y,z);
}
else if(k==2)
{
x=read(),y=read();
write(Tc.qRange(x,y));
}
else if(k==3)
{
x=read(),y=read();
Tc.upSon(x,y);
}
else
{
x=read();
write(Tc.qSon(x));
}
if(m) goto DO;
return 0;
}
两年后的码:
#include<bits/stdc++.h>
#define int long long
#define ll long long
#define fd(i,a,b) for(int i=(a);i<=(b);i=-~i)
#define bd(i,a,b) for(int i=(a);i>=(b);i=~-i)
#define db(x) cout<<"DEBUG "<<#x<<" = "<<x<<endl;
using namespace std;
const int N=1e5+509,M=5e5+509,Mod=998244353;
int n,m,mod,rp;
int val[N],v[N];
int son[N],fa[N],id[N],tot,dep[N],siz[N],top[N];
vector<int> e[N];
//--------------线段树-----------------
struct st
{
int l,r,sum,add;
#define ls(p) (p<<1)
#define rs(p) (p<<1|1)
#define sum(p) (t[p].sum)
#define add(p) (t[p].add)
#define len(p) (t[p].r-t[p].l+1)
}t[N<<2];
inline void pushup(int p)
{
sum(p)=(sum(ls(p))+sum(rs(p)))%mod;
}
inline void pushdown(int p)
{
if(!add(p)) return;
(add(ls(p))+=add(p))%=mod;
(add(rs(p))+=add(p))%=mod;
(sum(ls(p))+=add(p)*len(ls(p))%mod)%=mod;
(sum(rs(p))+=add(p)*len(rs(p))%mod)%=mod;
add(p)=0;
}
#define mid(p) (((t[p].r-t[p].l)>>1)+t[p].l)
void build(int p,int l,int r)
{
t[p].l=l,t[p].r=r,add(p)=0,sum(p)=0;
if(l==r) {sum(p)=val[l]%mod;return;}
build(ls(p),l,mid(p));
build(rs(p),mid(p)+1,r);
pushup(p);
}
int ask(int p,int l,int r)
{
if(l<=t[p].l&&t[p].r<=r) return sum(p);
pushdown(p);
int res=0;
if(l<=mid(p)) res+=ask(ls(p),l,r)%mod;
if(r>mid(p)) res+=ask(rs(p),l,r)%mod;
pushup(p);
return res%mod;
}
void change(int p,int l,int r,int k)
{
if(l<=t[p].l&&t[p].r<=r)
{
(add(p)+=k)%=mod;
(sum(p)+=k*len(p)%mod)%=mod;
return;
}
pushdown(p);
if(l<=mid(p)) change(ls(p),l,r,k);
if(r>mid(p)) change(rs(p),l,r,k);
pushup(p);
}
//---------------END------------------
//--------------树 剖-----------------
//找重儿子
void getson(int x,int fat,int d)
{
dep[x]=d;fa[x]=fat;siz[x]=1;
int maxx=-1;
for(auto y:e[x])
{
if(y==fa[x]) continue;
getson(y,x,d+1);
siz[x]+=siz[y];
if(siz[y]>maxx)
{
son[x]=y;
maxx=siz[y];
}
}
}
//找链顶+dfs序
void gettop(int x,int tp)
{
id[x]=++tot;val[id[x]]=v[x];top[x]=tp;
if(!son[x]) return;gettop(son[x],tp);
for(auto y:e[x])
{
if(y==fa[x]||y==son[x]) continue;
gettop(y,y);
}
}
//查询x~y路径和
inline int askR(int x,int y)
{
int res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
(res+=ask(1,id[top[x]],id[x]))%=mod;
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
(res+=ask(1,id[x],id[y]))%=mod;
return res%mod;
}
//修改x~y路径
inline int updR(int x,int y,int k)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
change(1,id[top[x]],id[x],k);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
change(1,id[x],id[y],k);
}
//查询x子树和
inline int askS(int x)
{
return ask(1,id[x],id[x]+siz[x]-1);
}
//修改x子树和
inline void updS(int x,int k)
{
change(1,id[x],id[x]+siz[x]-1,k);
}
//---------------END------------------
signed main()
{
// #define FJ
#ifdef FJ
freopen(".in","r",stdin);
freopen(".out","w",stdout);
#else
// freopen("A.in","r",stdin);
// freopen("A.out","w",stdout);
#endif
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m>>rp>>mod;
fd(i,1,n) cin>>v[i];
fd(i,1,n-1)
{
int x,y;cin>>x>>y;
e[x].push_back(y);
e[y].push_back(x);
}
getson(rp,0,1);
gettop(rp,rp);
build(1,1,n);
while(m--)
{
int k,x,y,z;
cin>>k;
if(k==1)
{
cin>>x>>y>>z;
updR(x,y,z);
}
else if(k==2)
{
cin>>x>>y;
cout<<askR(x,y)<<endl;
}
else if(k==3)
{
cin>>x>>y;
updS(x,y);
}
else
{
cin>>x;
cout<<askS(x)<<endl;
}
}
return 0;
}
本文来自博客园,作者:whrwlx,转载请注明原文链接:https://www.cnblogs.com/whrwlx/p/18306072