2568. 树链剖分

题目链接

2568. 树链剖分

给定一棵树,树中包含 \(n\) 个节点(编号 \(1 \sim n\) ),其中第 \(i\) 个节点的权值为 \(a_{i}\)
初始时, \(1\) 号节点为树的根节点。
现在要对该树进行 \(m\) 次操作,操作分为以下 4 种类型:

  • 1 u v k ,修改路径上节点权值,将节点 \(u\) 和节点 \(v\) 之间路径上的所有节点(包括这两个节点)的权值增加 \(k\)
  • 2 u k ,修改子树上节点权值,将以节点 \(u\) 为根的子树上的所有节点的权值增加 \(k\)
  • 3 u v ,询问路径,询问节点 \(u\) 和节点 \(v\) 之间路径上的所有节点(包括这两个节点)的权值和。
  • 4 u ,询问子树,询问以节点 \(u\) 为根的子树上的所有节点的权值和。

输入格式

第一行包含一个整数 \(n\) ,表示节点个数。
第二行包含 \(n\) 个整数,其中第 \(i\) 个整数表示 \(a_{i}\)
接下来 \(n-1\) 行,每行包含两个整数 \(x, y\) ,表示节点 \(x\) 和节点 \(y\) 之间存在一条边。
再一行包含一个整数 \(m\) ,表示操作次数。
接下来 \(m\) 行,每行包含一个操作,格式如题目所述。

输出格式

对于每个操作 3 和操作 4 ,输出一行一个整数表示答案。

数据范围

$1 \leq n, m \leq 10^{5} $
$0 \leq a_{i}, k \leq 10^{5} $
\(1 \leq u, v, x, y \leq n\)

输入样例:

5
1 3 7 4 5
1 3
1 4
1 5
2 3
5
1 3 4 3
3 5 4
1 3 5 10
2 3 5
4 1

输出样例:

16
69

解题思路

树链剖分

树链剖分是一种思想而不是一种数据结构,其关键在于:

  • 将树转化为一段区间

  • 将树上的任意一条路径转化为 \(O(logn)\) 段区间

这样,在树上的操作等价于在区间上的操作,而区间上的操作可以利用一些数据结构进行

其次关键在于如何转化为一段区间:先求出重儿子(节点个数最多的儿子,且每个有儿子节点的父节点有且仅有一个重儿子),然后按 \(dfs\) 序优先遍历重儿子的原则给节点重新编号,进而求出一段连续区间

对子树进行修改和查询,由于 \(dfs\) 遍历方式的原因,子树中的编号都是连续的,相当于对一段区间进行修改和查询

对路径进行修改和查询,类似于 LCA 的思想,每次提取最低的重链构成的区间,然后对这些区间进行修改和查询

  • 时间复杂度:\(O(mlog^2n)\)

代码

// Problem: 树链剖分
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/2570/
// Memory Limit: 64 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

// %%%Skyqwq
#include <bits/stdc++.h>
 
//#define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
 
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
 
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
 
template <typename T> void inline read(T &x) {
    int f = 1; x = 0; char s = getchar();
    while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
    while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
    x *= f;
}

const int N=1e5+5;
int n,w[N],nw[N],sz[N],dep[N],fa[N],top[N],son[N],id[N],cnt;
struct T
{
	int l,r;
	LL add,sum;
}tr[N*4];
vector<int> adj[N];
void dfs(int x,int father,int depth)
{
	sz[x]=1,fa[x]=father,dep[x]=depth;
	for(int y:adj[x])
	{
		if(y==father)continue;
		dfs(y,x,depth+1);
		sz[x]+=sz[y];
		if(sz[son[x]]<sz[y])son[x]=y;
	}
}
void dfs1(int x,int t)
{
	id[x]=++cnt,nw[cnt]=w[x],top[x]=t;
	if(!son[x])return ;
	dfs1(son[x],t);
	for(int y:adj[x])
	{
		if(y==fa[x]||y==son[x])continue;
		dfs1(y,y);
	}
}
void pushup(int u)
{
	tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
}
void pushdown(int u)
{
	if(tr[u].add)
	{
		tr[u<<1].add+=tr[u].add,tr[u<<1|1].add+=tr[u].add;
		tr[u<<1].sum+=(tr[u<<1].r-tr[u<<1].l+1)*tr[u].add;
		tr[u<<1|1].sum+=(tr[u<<1|1].r-tr[u<<1|1].l+1)*tr[u].add;
		tr[u].add=0;
	}
}
void build(int u,int l,int r)
{
	tr[u].l=l,tr[u].r=r;
	if(l==r)
	{
		tr[u].sum=nw[l];
		return ;
	}
	int mid=l+r>>1;
	build(u<<1,l,mid),build(u<<1|1,mid+1,r);
	pushup(u);
}
void update(int u,int l,int r,int k)
{
	if(l<=tr[u].l&&tr[u].r<=r)
	{
		tr[u].sum+=k*(tr[u].r-tr[u].l+1);
		tr[u].add+=k;
		return ;
	}
	pushdown(u);
	int mid=tr[u].l+tr[u].r>>1;
	if(l<=mid)update(u<<1,l,r,k);
	if(r>mid)update(u<<1|1,l,r,k);
	pushup(u);
}
LL ask(int u,int l,int r)
{
	if(l<=tr[u].l&&tr[u].r<=r)return tr[u].sum;
	pushdown(u);
	LL res=0;
	int mid=tr[u].l+tr[u].r>>1;
	if(l<=mid)res+=ask(u<<1,l,r);
	if(r>mid)res+=ask(u<<1|1,l,r);
	return res;
}
void update_path(int u,int v,int k)
{
	while(top[u]!=top[v])
	{
		if(dep[top[u]]<dep[top[v]])swap(u,v);
		update(1,id[top[u]],id[u],k);
		u=fa[top[u]];
	}
	if(dep[u]>dep[v])swap(u,v);
	update(1,id[u],id[v],k);
}
LL ask_path(int u,int v)
{
	LL res=0;
	while(top[u]!=top[v])
	{
		if(dep[top[u]]<dep[top[v]])swap(u,v);
		res+=ask(1,id[top[u]],id[u]);
		u=fa[top[u]];
	}
	if(dep[u]>dep[v])swap(u,v);
	res+=ask(1,id[u],id[v]);
	return res;
}
void update_tree(int u,int k)
{
	update(1,id[u],id[u]+sz[u]-1,k);
}
LL ask_tree(int u)
{
	return ask(1,id[u],id[u]+sz[u]-1);
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)scanf("%d",&w[i]);
    for(int i=1;i<n;i++)
    {
    	int x,y;
    	scanf("%d%d",&x,&y);
    	adj[x].pb(y),adj[y].pb(x);
    }
    dfs(1,-1,1);
    dfs1(1,1);
    build(1,1,n);
    int op,u,v,k,m;
    scanf("%d",&m);
    while(m--)
    {
    	scanf("%d%d",&op,&u);
    	if(op==1)
    	{
    		scanf("%d%d",&v,&k);
    		update_path(u,v,k);
    	}
    	else if(op==2)
    	{
    		scanf("%d",&k);
    		update_tree(u,k);	
    	}
    	else if(op==3)
    	{
    		scanf("%d",&v);
    		printf("%lld\n",ask_path(u,v));
    	}
    	else
    		printf("%lld\n",ask_tree(u));
    }
    return 0;
}
posted @ 2022-04-18 13:03  zyy2001  阅读(37)  评论(0编辑  收藏  举报