【洛谷P3384】【模板】树链剖分
题目大意:
题目链接:https://www.luogu.org/problem/P3384
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 表示将树从到结点最短路径上所有节点的值都加上
操作2: 格式: 表示求树从到结点最短路径上所有节点的值之和
操作3: 格式: 表示将以为根节点的子树内所有节点值都加上
操作4: 格式: 表示求以为根节点的子树内所有节点值之和
思路:
树链剖分模板题。
大部分思路、学习过程来自这里:https://www.luogu.org/blog/communist/shu-lian-pou-fen-yang-xie
我们定义如下内容
- 重儿子:一个结点的儿子中,子树最大的儿子
- 轻儿子:该节点除了重儿子以外的儿子
- 重边:重儿子与他父亲的连边
- 轻边:轻儿子与他父亲的连边
- 重链:多条重链连接起来的路径
- 轻链:多条轻边连接起来的路径
之后我们需要进行两次。
第一次我们求出每一个节点的父亲,深度,以及子树大小。分别用记录。
同时还要记录除叶子外每个节点的重儿子。用记录。
然后第二次我们将重链优先编号,这样用数据结构维护时就可以更加方便。同时,我们需要满任意节点的子树编号依然为一段连续的区间。同时记录下每一条重链的起始点,也就是该重链的深度最浅的点,然后记录下每一个点的编号(注意这个编号和序略有不同),以及该编号对应的节点。分别用记录。
void dfs1(int x,int f)
{
fa[x]=f;
dep[x]=dep[fa[x]]+1;
size[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int y=e[i].to;
if (y!=fa[x])
{
dfs1(y,x);
size[x]+=size[y];
if (size[y]>size[son[x]]) son[x]=y;
}
}
}
void dfs2(int x,int tp)
{
top[x]=tp;
id[x]=++cnt;
rk[cnt]=x;
if (son[x]) dfs2(son[x],tp);
for (int i=head[x];~i;i=e[i].next)
{
int y=e[i].to;
if (y!=fa[x] && y!=son[x]) dfs2(y,y);
}
}
然后接下来我们就要处理操作了
1.将树从到结点最短路径上所有节点的值都加上
我们已经保证了每一条重链编号是连续的,所以我们每次在线段树中只要维护若干个区间加。每次选取两点中深度较深的点,然后将该点到该点所在重链的区间加即可,然后把赋值为。
2.求树从到结点最短路径上所有节点的值之和
和操作1的思路是相同的,每次求到其重链的区间和
3.将以为根节点的子树内所有节点值都加上
由于序列只是在序上稍加修改,我们依然可以保证一棵子树任然在同一个区间。
那么如果这棵子树的根的编号为,我们已经处理出了该子树的大小,所以我们要进行区间加的区间为。
4.求以为根节点的子树内所有节点值之和
这个其实就是线段树的模板,区间查询的和即可。
可以证明树链剖分的时间复杂度为
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=100010;
int size[N],fa[N],dep[N],id[N],rk[N],son[N],top[N],a[N],head[N];
int n,m,root,MOD,opt,cnt,tot;
struct edge
{
int next,to;
}e[N*2];
struct Treenode
{
int l,r,sum,lazy;
};
struct Tree
{
Treenode tree[N*4];
int len(int x)
{
return tree[x].r-tree[x].l+1;
}
void pushup(int x)
{
tree[x].sum=(tree[x*2].sum+tree[x*2+1].sum)%MOD;
}
void pushdown(int x)
{
if (tree[x].lazy)
{
tree[x*2].lazy=(tree[x*2].lazy+tree[x].lazy)%MOD;
tree[x*2+1].lazy=(tree[x*2+1].lazy+tree[x].lazy)%MOD;
tree[x*2].sum=(tree[x*2].sum+tree[x].lazy*len(x*2))%MOD;
tree[x*2+1].sum=(tree[x*2+1].sum+tree[x].lazy*len(x*2+1))%MOD;
tree[x].lazy=0;
}
}
void build(int x)
{
if (tree[x].l==tree[x].r)
{
tree[x].sum=a[rk[tree[x].l]]%MOD;
return;
}
int mid=(tree[x].l+tree[x].r)>>1;
tree[x*2].l=tree[x].l;
tree[x*2].r=mid;
tree[x*2+1].l=mid+1;
tree[x*2+1].r=tree[x].r;
build(x*2); build(x*2+1);
pushup(x);
}
void update(int x,int l,int r,int val)
{
if (tree[x].l==l && tree[x].r==r)
{
tree[x].sum=(tree[x].sum+val*len(x))%MOD;
tree[x].lazy=(tree[x].lazy+val)%MOD;
return;
}
pushdown(x);
int mid=(tree[x].l+tree[x].r)>>1;
if (r<=mid) update(x*2,l,r,val);
else if (l>mid) update(x*2+1,l,r,val);
else update(x*2,l,mid,val),update(x*2+1,mid+1,r,val);
pushup(x);
}
void addrange(int x,int y,int k)
{
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,id[top[x]],id[x],k);
x=fa[top[x]];
}
if (id[x]>id[y]) update(1,id[y],id[x],k);
else update(1,id[x],id[y],k);
}
int ask(int x,int l,int r)
{
if (tree[x].l==l && tree[x].r==r) return tree[x].sum;
pushdown(x);
int mid=(tree[x].l+tree[x].r)>>1;
if (r<=mid) return ask(x*2,l,r);
if (l>mid) return ask(x*2+1,l,r);
return (ask(x*2,l,mid)+ask(x*2+1,mid+1,r))%MOD;
}
int askrange(int x,int y)
{
int ans=0;
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]]) swap(x,y);
ans=(ans+ask(1,id[top[x]],id[x]))%MOD;
x=fa[top[x]];
}
if (id[x]>id[y]) ans=(ans+ask(1,id[y],id[x]))%MOD;
else ans=(ans+ask(1,id[x],id[y]))%MOD;
return ans;
}
}Tree;
void add(int from,int to)
{
e[++tot].to=to;
e[tot].next=head[from];
head[from]=tot;
}
void dfs1(int x,int f)
{
fa[x]=f;
dep[x]=dep[fa[x]]+1;
size[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int y=e[i].to;
if (y!=fa[x])
{
dfs1(y,x);
size[x]+=size[y];
if (size[y]>size[son[x]]) son[x]=y;
}
}
}
void dfs2(int x,int tp)
{
top[x]=tp;
id[x]=++cnt;
rk[cnt]=x;
if (son[x]) dfs2(son[x],tp);
for (int i=head[x];~i;i=e[i].next)
{
int y=e[i].to;
if (y!=fa[x] && y!=son[x]) dfs2(y,y);
}
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d%d%d",&n,&m,&root,&MOD);
for (int i=1;i<=n;i++)
scanf("%d",&a[i]);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
dfs1(root,0);
dfs2(root,root);
Tree.tree[1].l=1; Tree.tree[1].r=n;
Tree.build(1);
for (int i=1,x,y,z;i<=m;i++)
{
scanf("%d",&opt);
if (opt==1)
{
scanf("%d%d%d",&x,&y,&z);
Tree.addrange(x,y,z);
}
if (opt==2)
{
scanf("%d%d",&x,&y);
printf("%d\n",Tree.askrange(x,y));
}
if (opt==3)
{
scanf("%d%d",&x,&y);
Tree.update(1,id[x],id[x]+size[x]-1,y);
}
if (opt==4)
{
scanf("%d",&x);
printf("%d\n",Tree.ask(1,id[x],id[x]+size[x]-1));
}
}
return 0;
}