点分树(动态点分治)学习笔记

1. 定义

在点分治的基础上加以变化,构造一颗支持快速修改的重构树,称之为点分树

2.算法

2.1. 思路

点分治的核心在于通过树的重心来划分联通块,减少合并层数,从而降低时间复杂度

所以,我们可以按分治递归的顺序提出一颗树,易知树高至多为logn

具体的说,对于每一个找到的重心,将上一次分治时的重心设为其父亲,就可以得到一颗logn层的虚树(重构树)

举个例子,原树为:

新树为:

此时,有一个性质,所有子树的子树大小之和为nlogn

证明:每个点会被从根到它的路径上最多logn个祖先所统计,所以必然小于nlogn

所以在新树上修改,只需要暴力儿子跳父亲即可

2.2. 应用

统计一个点到其他点的点权和,即\(\sum_{y=1}^n dis(x,y)\),对于任意一个y,找到它与x在虚树上的lca,易知在以此点为重心划分子连通块时x,y会首次被分割开来,因此该点必定在原树的x,y路径上。

所以我们只需要在这些lca的虚子树中寻找y即可,此时记录虚子树信息的作用便显现出来了。

而对于一个x,可能的lca最多存在logn个,因此通常使用暴力枚举+简单容斥的方法来统计y的贡献。

3.具体实现

3.1. 例题:P6329 【模板】点分树 | 震波

oj:https://gxyzoj.com/d/gxyznoi/p/P17

题意:维护一颗带点权树,需要支持两种操作:修改x的点权,查询与点x距离不超过k的点权值之和。

3.2. 思路

3.2.1. 建树

在找到第一个重心rt后,先遍历得到整颗树的信息,然后删除rt,再递归处理其他联通块

操作和点分治基本相同,但是将统计答案变为统计子树信息

3.2.2. 计算贡献

\(a_i\)为点i的权值,\(f_{i,j}\)为与i距离不超过j的点的权值之和,则:

\[f_{u,i}=\sum_{v\in subtree(u),dis(u,v)\le i} a_v \]

为去重,还需记录i子树内与fa距离不超过j的点的权值之和,即:

\[g_{u,i}=\sum_{v\in subtree(u),dis(fa,v)\le i} a_v \]

在一次查询(x,k)中,对于虚树上的一对父子节点\((u,fa)\)\(subtree(fa)-subtree(u)\)的贡献\(sum_{u,fa}=f_{fa,k-dis(x,fa)}-g_{u,k-dis(x,fa)}\)

则答案为\(ans=f_{x,k}+\sum sum_{u,fa}\)

用线段树维护即可

3.3. 代码

#include<cstdio>
#include<algorithm>
using namespace std;
int n,m,a[100005],head[100005],edgenum;
struct edge{
	int to,nxt,val;
}e[200005];
void add_edge(int u,int v)
{
	e[++edgenum].nxt=head[u];
	e[edgenum].to=v;
	head[u]=edgenum;
}
bool vis[100005];
int sum,size[100005],mx[100005],tot,p[100005];
void dfs1(int u,int fa)
{
	p[++tot]=u;
	size[u]=1,mx[u]=0;
	for(int i=head[u];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(vis[v]||v==fa) continue;
		dfs1(v,u);
		size[u]+=size[v];
		mx[u]=max(mx[u],size[v]);
	}
}
int dep[100005],f[100005][21];
void dfs2(int u,int fa)
{
	//printf("%d %d\n",u,fa);
	dep[u]=dep[fa]+1;
	f[u][0]=fa;
	for(int i=1;i<=20;i++)
	{
		f[u][i]=f[f[u][i-1]][i-1];
	}
	for(int i=head[u];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(v==fa) continue;
		dfs2(v,u);
	}
}
int lca(int x,int y)
{
	if(dep[x]<dep[y]) swap(x,y);
	for(int i=20;i>=0;i--)
	{
		if(dep[f[x][i]]>=dep[y])
		{
			x=f[x][i];
		}
		if(x==y) return x;
	}
	for(int i=20;i>=0;i--)
	{
		if(f[x][i]!=f[y][i])
		{
			x=f[x][i];
			y=f[y][i];
		}
	}
	return f[x][0];
}
int rt1[100005],val1[10000005],ls1[10000005],rs1[10000005],idx1;
int rt2[100005],val2[10000005],ls2[10000005],rs2[10000005],idx2;
int f1[100005];
int getdis(int x,int y)
{
	return dep[x]+dep[y]-2*dep[lca(x,y)];
}
int add1(int id,int l,int r,int x,int v)
{
	if(!id) id=++idx1;
	if(l==r)
	{
		val1[id]+=v;
		return id;
	}
	int mid=(l+r)>>1;
	if(x<=mid) ls1[id]=add1(ls1[id],l,mid,x,v);
	else rs1[id]=add1(rs1[id],mid+1,r,x,v);
	val1[id]=val1[ls1[id]]+val1[rs1[id]];
	return id;
}
int add2(int id,int l,int r,int x,int v)
{
	if(!id) id=++idx2;
	if(l==r)
	{
		val2[id]+=v;
		return id;
	}
	int mid=(l+r)>>1;
	if(x<=mid) ls2[id]=add2(ls2[id],l,mid,x,v);
	else rs2[id]=add2(rs2[id],mid+1,r,x,v);
	val2[id]=val2[ls2[id]]+val2[rs2[id]];
	return id;
}
int query1(int id,int l,int r,int ql,int qr)
{
	if(!id||l>qr||r<ql) return 0;
	if(l>=ql&&r<=qr)
	{
		return val1[id];
	}
	int mid=(l+r)>>1;
	int res=0;
	if(mid>=ql) res=query1(ls1[id],l,mid,ql,qr);
	if(mid<qr) res+=query1(rs1[id],mid+1,r,ql,qr);
	return res;
}
int query2(int id,int l,int r,int ql,int qr)
{
	if(!id||l>qr||r<ql||l>r) return 0;
	if(l>=ql&&r<=qr)
	{
		return val2[id];
	}
	int mid=(l+r)>>1;
	int res=0;
	if(mid>=ql) res=query2(ls2[id],l,mid,ql,qr);
	if(mid<qr) res+=query2(rs2[id],mid+1,r,ql,qr);
	return res;
}
void dfs3(int id,int u,int fa,int len)
{
	rt1[id]=add1(rt1[id],1,n,len,a[u]);
	for(int i=head[u];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(v==fa||vis[v]) continue;
		dfs3(id,v,u,len+1);
	}
}
void dfs(int u,int fa)
{
	tot=0;
	dfs1(u,0);
	mx[0]=1e9;
	for(int i=1;i<=tot;i++)
	{
		mx[p[i]]=max(mx[p[i]],tot-size[p[i]]);
	}
	int root=0;
	for(int i=1;i<=tot;i++)
	{
		if(mx[root]>mx[p[i]]) root=p[i];
	}
	u=root;
	f1[u]=fa;
	vis[u]=1;
	if(fa)
	{
		for(int i=1;i<=tot;i++)
		{
			rt2[u]=add2(rt2[u],1,n,getdis(fa,p[i]),a[p[i]]);
		}
	}
	for(int i=head[u];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(vis[v]) continue;
		dfs3(u,v,u,1);
	}
	for(int i=head[u];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(vis[v]) continue;
		dfs(v,u);
	}
}
void update(int x,int k)
{
	int u=x;
	while(f1[u])
	{
		int d=getdis(x,f1[u]);
		rt2[u]=add2(rt2[u],1,n,d,-a[x]);
		rt2[u]=add2(rt2[u],1,n,d,k);
		u=f1[u];
		rt1[u]=add1(rt1[u],1,n,d,-a[x]);
		rt1[u]=add1(rt1[u],1,n,d,k);
	}
	a[x]=k;
}
int getans(int x,int k)
{
	int u=x;
	int res=0;
	res+=a[x]+query1(rt1[x],1,n,1,k);
	while(f1[u])
	{
		int d=getdis(x,f1[u]);
		if(d<=k) res+=a[f1[u]]+query1(rt1[f1[u]],1,n,1,k-d)-query2(rt2[u],1,n,1,k-d);
		u=f1[u];
	}
	return res;
}
int main()
{
//	freopen("1.txt","r",stdin);
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
	{
		scanf("%d",&a[i]);
	}
	for(int i=1;i<n;i++)
	{
		int u,v;
		scanf("%d%d",&u,&v);
		add_edge(u,v);
		add_edge(v,u);
	}
	dfs2(1,0);
	dfs(1,0);
	int ans=0;
	for(int i=1;i<=m;i++)
	{
		int opt,x,y;
		scanf("%d%d%d",&opt,&x,&y);
		x^=ans,y^=ans;
	//	printf("%d %d %d\n",i,x,y);
		if(opt==0)
		{
			ans=getans(x,y);
			printf("%d\n",ans);
		}
		else update(x,y);
	}
	return 0;
}
posted @ 2024-04-17 11:40  wangsiqi2010916  阅读(25)  评论(0编辑  收藏  举报