树链剖分学习笔记

前言

本文涉及算法:线段树、dfs。树链剖分是码量十分巨大的数据结构,但十分有用。

引子

一道来源不明的题:

给一棵树,每个结点都有一个点权\(a_i\),求从\(x\)\(y\)的简单路径上的点权和。

\(10^5\)次询问。

方法一:我会暴力!

强行枚举从\(x\)\(y\)进行求和。

时间复杂度\(O(n^2)\)

那有没有更优秀的算法呢?

方法二:我会树上前缀和!

对于\(ans_i = ans_{fa_i} + a_i\)

查询时输出\(ans_x + ans_y - ans_{lca(x,y)}\)

时间复杂度\(O(nlog_n)\)

(lca(x,y)为x和y的最近公共祖先)

树上前缀和的复杂度已经十分优秀了,完全可以解决静态树上求和。

引子2

又一道来源不明的题:

给一棵树,每个结点都有一个点权\(a_i\),可以进行从\(x\)\(y\)的简单路径上的修改,询问从\(x\)\(y\)的简单路径上的权和。

方法一:我会暴力!

同上。

时间复杂度\(O(n^2)\)

方法二:我会树上前缀和!

对于每一次点权修改重新修改前缀和。

查询方法同上。

当有了修改操作,树上前缀和的时间复杂度貌似变为:

\(O(n^2)!\)

这个方法几乎等同于暴力。

通过仔细打表发现,每次修改树上前缀和的时间太高了,有没有不浪费的算法又资瓷修改操作的呢?

有!我会线段树!

但是我们发现,线段树不资瓷的是树上修改,那我们有没有方法将树劈成一些链来处理呢?

正文

有,树链剖分!

为了行文方便,我们先在这定义一下一些代名词:

重儿子\(son_i\):他的父节点中子树结点数量最多的儿子。

轻儿子 : 他的父节点中不是重儿子的其他儿子。

重链 : 由重儿子连接而成的链。

轻链 :轻儿子组成连接而成的链。

树链剖分有一个主题思想就是将一棵树变成一堆链来线段树。

对于查询操作

我们先找到这棵树的重儿子,将重儿子连成重链,剩下的全部连成轻链,我们记下所有链的顶端,每次求和完成后跳到他的顶端继续游戏。

(注:轻链的顶端是他自己)

其实这个过程就是倍增LCA的过程,每次跳的是深度深的结点,因为防止跳过头。

我们可以用线段树来维护重链的值。

基本变量声明

int fa[200005],son[200005],head[200005],size[200005];
int d[200005],rk[200005];
int top[200005],id[200005],w[200005];
struct E{
	int next,to;
} edge[200005];
struct T{
	int l,r,w,f;
} a[500000];

\(fa_i\) 第i个结点的父亲

\(son_i\) 第i个结点的重儿子

\(head\)\(edge\)为前向星用的数组

\(d_i\) 为第i个结点的深度

\(size_i\) 为第i个结点子树结点的数量

\(top_i\) 为链的顶端

\(id_i\) 为第i个结点在线段树中对应的结点标号

\(rk_i\) 为线段树第i个结点中对应现实结点的标号

\(w_i\)为读入的点权

\(a\)为线段树数组

找到重儿子!

void dfs1(int x)
{
	size[x] = 1;//一开始x的子树结点只有自己
	d[x] = d[fa[x]] + 1;//深度等于他父亲的深度+1
	for (int v,i = head[x]; i; i = edge[i].next)
	if ((v = edge[i].to) != fa[x]) //找到的是他的儿子
	{
		fa[v] = x;//下一个结点的父亲是自己
		dfs1(v);
		size[x] += size[v];//合并子树
		if (size[son[x]] < size[v])
		  son[x] = v; //如果子树节点数比max大,设为重儿子
    } 
}

切树成链!

void dfs2(int x,int tp)
{
	top[x] = tp;//定义顶端
	id[x] = ++sum;
	rk[sum] = x;
	if (son[x])
	   dfs2(son[x],tp);//重儿子优先成重链
	for (int v,i = head[x]; i; i = edge[i].next)
	if ((v = edge[i].to) != fa[x] && v != son[x]) //把轻儿子割出来
	   dfs2(v,v);//轻儿子的top是自己
}

线段树!

void build(int x,int l,int r)
{
	a[x].l = l;
	a[x].r = r;
	if (l == r)
	{
		a[x].w = w[rk[l]];
		if (a[x].w > Mod)
		  a[x].w %= Mod;
		return;
	}
	int mid = (l + r) / 2;
	build(x * 2,l,mid);
	build(x * 2 + 1,mid + 1,r);
	a[x].w = (a[x * 2].w + a[x * 2 + 1].w) % Mod; 
}
void down(int x)
{
    a[x * 2].f += a[x].f;
	a[x * 2 + 1].f += a[x].f;
	a[x * 2].w += a[x].f * (a[x * 2].r - a[x * 2].l + 1) % Mod;
	a[x * 2 + 1].w += a[x].f * (a[x * 2 + 1].r - a[x * 2 + 1].l + 1) % Mod;
	a[x].f = 0;	
}
void change_interval(int k)
{
    if (a[k].l >= as && a[k].r <= bs)
	{
	    a[k].w += g * (a[k].r - a[k].l + 1) % Mod;
	    a[k].f += g;
	    return;
	}	
	if (a[k].f) down(k);
	int mid = (a[k].l + a[k].r) / 2;
	if (as <= mid)
	  change_interval(k * 2);
	if (mid < bs)
	  change_interval(k * 2 + 1);
	a[k].w = (a[k * 2].w + a[k * 2 + 1].w) % Mod; 
}
void ask_interval(int k)
{
    if (a[k].l >= as && a[k].r <= bs)
	{
            ans = (ans + a[k].w) % Mod;
	    return;
	}	
	if (a[k].f) down(k);
	int mid = (a[k].l + a[k].r) / 2;
	if (as <= mid)
	  ask_interval(k * 2);
	if (mid < bs)
	  ask_interval(k * 2 + 1); 
}

//不解释。

链操作!

int Si(int x,int y)//输入从x到y的简单路径上的值
{
	ans = 0;
	while (top[x] != top[y])//没有在一起
	{
		if (d[top[x]] < d[top[y]])//我们保证x深度深
		  swap(x,y);
		as = id[top[x]];
		bs = id[x];
		ask_interval(1);//左区间为链顶,右区间为链尾
		x = fa[top[x]];//继续向上
	}
	if (id[x] > id[y])
	  swap(x,y);
	as = id[x];
	bs = id[y];//同理
	ask_interval(1);
	return ans % Mod;
}//链询问
int ts(int x,int y)
{
	while (top[x] != top[y])
	{
		if (d[top[x]] < d[top[y]])
		  swap(x,y);
		as = id[top[x]];
		bs = id[x];
		change_interval(1);
		x = fa[top[x]];
	}
	if (id[x] > id[y])
	  swap(x,y);
	as = id[x];
	bs = id[y];
//这里同查询
}//链修改

合并以上代码

只要将以上代码合并,

我们可以在\(O(nlog_n)\)的复杂度内A掉在引子2出现的例题了。

总代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,m,fa[200005],son[200005],edgenum,head[200005],size[200005];
int d[200005],rk[200005],g,ans,as,bs,r,Mod,sum;
int top[200005],id[200005],w[200005];
struct E{
	int next,to;
} edge[200005];
struct T{
	int l,r,w,f;
} a[500000];
void ins(int x,int y)
{ 
    edge[++edgenum].to = y;
	edge[edgenum].next = head[x];
	head[x] = edgenum;	
}
void dfs1(int x)
{
	size[x] = 1;
	d[x] = d[fa[x]] + 1;
	for (int v,i = head[x]; i; i = edge[i].next)
	if ((v = edge[i].to) != fa[x]) 
	{
		fa[v] = x;
		dfs1(v);
		size[x] += size[v];
		if (size[son[x]] < size[v])
		  son[x] = v; 
    } 
}
void dfs2(int x,int tp)
{
	top[x] = tp;
	id[x] = ++sum;
	rk[sum] = x;
	if (son[x])
	   dfs2(son[x],tp);
	for (int v,i = head[x]; i; i = edge[i].next)
	if ((v = edge[i].to) != fa[x] && v != son[x]) 
	   dfs2(v,v);
}
void build(int x,int l,int r)
{
	a[x].l = l;
	a[x].r = r;
	if (l == r)
	{
		a[x].w = w[rk[l]];
		if (a[x].w > Mod)
		  a[x].w %= Mod;
		return;
	}
	int mid = (l + r) / 2;
	build(x * 2,l,mid);
	build(x * 2 + 1,mid + 1,r);
	a[x].w = (a[x * 2].w + a[x * 2 + 1].w) % Mod; 
}
void down(int x)
{
    a[x * 2].f += a[x].f;
	a[x * 2 + 1].f += a[x].f;
	a[x * 2].w += a[x].f * (a[x * 2].r - a[x * 2].l + 1) % Mod;
	a[x * 2 + 1].w += a[x].f * (a[x * 2 + 1].r - a[x * 2 + 1].l + 1) % Mod;
	a[x].f = 0;	
}
void change_interval(int k)
{
    if (a[k].l >= as && a[k].r <= bs)
	{
		//cout<<g<<endl;
	    a[k].w += g * (a[k].r - a[k].l + 1) % Mod;
	    a[k].f += g;
	    return;
	}	
	if (a[k].f) down(k);
	int mid = (a[k].l + a[k].r) / 2;
	if (as <= mid)
	  change_interval(k * 2);
	if (mid < bs)
	  change_interval(k * 2 + 1);
	a[k].w = (a[k * 2].w + a[k * 2 + 1].w) % Mod; 
}
void ask_interval(int k)
{
    if (a[k].l >= as && a[k].r <= bs)
	{
        ans = (ans + a[k].w) % Mod;
        //cout<<"Test:"<<ans<<endl;
	    return;
	}	
	if (a[k].f) down(k);
	int mid = (a[k].l + a[k].r) / 2;
	if (as <= mid)
	  ask_interval(k * 2);
	if (mid < bs)
	  ask_interval(k * 2 + 1); 
}
int Si(int x,int y)
{
	ans = 0;
	while (top[x] != top[y])
	{
		if (d[top[x]] < d[top[y]])
		  swap(x,y);
		as = id[top[x]];
		bs = id[x];
		ask_interval(1);
		//cout<<ans<<endl;
		//ans = ans % Mod;
		x = fa[top[x]];
	}
	if (id[x] > id[y])
	  swap(x,y);
	as = id[x];
	bs = id[y];
	ask_interval(1);
	//cout<<ans<<endl;
	return ans % Mod;
}
int ts(int x,int y)
{
	while (top[x] != top[y])
	{
		if (d[top[x]] < d[top[y]])
		  swap(x,y);
		as = id[top[x]];
		bs = id[x];
		change_interval(1);
		x = fa[top[x]];
	}
	if (id[x] > id[y])
	  swap(x,y);
	as = id[x];
	bs = id[y];
	change_interval(1);
	//cout<<g<<endl;
}
signed main(){
	scanf("%lld%lld",&n,&m);
	Mod = 1000000008;
	for (int i = 1; i <= n; i++)
	  scanf("%lld",&w[i]);
	for (int i = 1; i < n; i++)
	{
		int x,y;
		scanf("%lld%lld",&x,&y);
		ins(x,y);
		ins(y,x);
	}
	dfs1(1);
	dfs2(1,1);
	build(1,1,n);
	for (int i = 1; i <= m; i++)
	{
		int op,x,y;
		scanf("%lld",&op);
		if (op == 1)
		{
			scanf("%lld%lld%lld",&x,&y,&g);
			ts(x,y);
		} else
		if (op == 2)
		{
			scanf("%lld%lld",&x,&y);
			printf("%lld\n",Si(x,y) % Mod);
		}
	}
	return 0;
}
 

luogu树链剖分模板

对于子树操作其实很好求:

将左区间定义为x,右区间定义为x+size[x]-1就ok了

因为这一段在线段树上已经是一段完整的区间了。

上代码:


#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,m,fa[200005],son[200005],edgenum,head[200005],size[200005];
int d[200005],rk[200005],g,ans,as,bs,r,Mod,sum;
int top[200005],id[200005],w[200005];
struct E{
	int next,to;
} edge[200005];
struct T{
	int l,r,w,f;
} a[500000];
void ins(int x,int y)
{ 
        edge[++edgenum].to = y;
	edge[edgenum].next = head[x];
	head[x] = edgenum;	
}
void dfs1(int x)
{
	size[x] = 1;
	d[x] = d[fa[x]] + 1;
	for (int v,i = head[x]; i; i = edge[i].next)
	if ((v = edge[i].to) != fa[x]) 
	{
		fa[v] = x;
		dfs1(v);
		size[x] += size[v];
		if (size[son[x]] < size[v])
		  son[x] = v; 
    } 
}
void dfs2(int x,int tp)
{
	top[x] = tp;
	id[x] = ++sum;
	rk[sum] = x;
	if (son[x])
	   dfs2(son[x],tp);
	for (int v,i = head[x]; i; i = edge[i].next)
	if ((v = edge[i].to) != fa[x] && v != son[x]) 
	   dfs2(v,v);
}
void build(int x,int l,int r)
{
	a[x].l = l;
	a[x].r = r;
	if (l == r)
	{
		a[x].w = w[rk[l]];
		if (a[x].w > Mod)
		  a[x].w %= Mod;
		return;
	}
	int mid = (l + r) / 2;
	build(x * 2,l,mid);
	build(x * 2 + 1,mid + 1,r);
	a[x].w = (a[x * 2].w + a[x * 2 + 1].w) % Mod; 
}
void down(int x)
{
        a[x * 2].f += a[x].f;
	a[x * 2 + 1].f += a[x].f;
	a[x * 2].w += a[x].f * (a[x * 2].r - a[x * 2].l + 1) % Mod;
	a[x * 2 + 1].w += a[x].f * (a[x * 2 + 1].r - a[x * 2 + 1].l + 1) % Mod;
	a[x].f = 0;	
}
void change_interval(int k)
{
    if (a[k].l >= as && a[k].r <= bs)
	{
	    a[k].w += g * (a[k].r - a[k].l + 1) % Mod;
	    a[k].f += g;
	    return;
	}	
	if (a[k].f) down(k);
	int mid = (a[k].l + a[k].r) / 2;
	if (as <= mid)
	  change_interval(k * 2);
	if (mid < bs)
	  change_interval(k * 2 + 1);
	a[k].w = (a[k * 2].w + a[k * 2 + 1].w) % Mod; 
}
void ask_interval(int k)
{
    if (a[k].l >= as && a[k].r <= bs)
	{
        ans = (ans + a[k].w) % Mod;
	    return;
	}	
	if (a[k].f) down(k);
	int mid = (a[k].l + a[k].r) / 2;
	if (as <= mid)
	  ask_interval(k * 2);
	if (mid < bs)
	  ask_interval(k * 2 + 1); 
}
int Si(int x,int y)
{
	ans = 0;
	while (top[x] != top[y])
	{
		if (d[top[x]] < d[top[y]])
		  swap(x,y);
		as = id[top[x]];
		bs = id[x];
		ask_interval(1);
		x = fa[top[x]];
	}
	if (id[x] > id[y])
	  swap(x,y);
	as = id[x];
	bs = id[y];
	ask_interval(1);
	return ans % Mod;
}
int ts(int x,int y)
{
	while (top[x] != top[y])
	{
		if (d[top[x]] < d[top[y]])
		  swap(x,y);
		as = id[top[x]];
		bs = id[x];
		change_interval(1);
		x = fa[top[x]];
	}
	if (id[x] > id[y])
	  swap(x,y);
	as = id[x];
	bs = id[y];
	change_interval(1);
}
signed main(){
	scanf("%lld%lld%lld%lld",&n,&m,&r,&Mod);
	for (int i = 1; i <= n; i++)
	  scanf("%lld",&w[i]);
	for (int i = 1; i < n; i++)
	{
		int x,y;
		scanf("%lld%lld",&x,&y);
		ins(x,y);
		ins(y,x);
	}
	dfs1(r);
	dfs2(r,r);
	build(1,1,n);
	for (int i = 1; i <= m; i++)
	{
		int op,x,y;
		scanf("%lld",&op);
		if (op == 1)
		{
			scanf("%lld%lld%lld",&x,&y,&g);
			ts(x,y);
		} else
		if (op == 2)
		{
			scanf("%lld%lld",&x,&y);
			printf("%lld\n",Si(x,y) % Mod);
		} else
		if (op == 3)
		{
			scanf("%lld%lld",&x,&y);
			as = id[x];
			bs = id[x] + size[x] - 1;
			g = y;
			change_interval(1);
		} else
		if (op == 4)
		{
			ans = 0;
			scanf("%lld",&x);
			as = id[x];
			bs = id[x] + size[x] - 1;
			ask_interval(1);
			printf("%lld\n",ans % Mod);
		} 
	}
	return 0;
}
 

[NOI2015]软件包管理器

链查询,子树修改。

在查询完毕后别忘了修改链。

代码:


#include<bits/stdc++.h>
#define int long long
using namespace std;
int head[100005],rk[100005],top[100005],id[100005],size[100005];
int son[100005],n,m,edgenum,d[100005],fa[100005],sum,ans;
int as,bs,insert[100005],cnt,anss,g;
struct E{
    int next,to;
} edge[300005];
struct T{
    int l,r,w,f;
} a[400005]; 
void ins(int x,int y)
{
    edge[++edgenum].to = y;
    edge[edgenum].next = head[x];
    head[x] = edgenum;
}
void dfs1(int x)
{
    size[x] = 1;
    d[x] = d[fa[x]] + 1;
    for (int v,i = head[x]; i; i = edge[i].next)
    if ((v = edge[i].to) != fa[x])
    {
        fa[v] = x;
        dfs1(v);
        size[x] += size[v];
        if (size[son[x]] < size[v] || !son[x])
          son[x] = v;
    }
}
void dfs2(int x,int tp)
{
    top[x] = tp;
    id[x] = ++sum;
    rk[sum] = x;
    if (son[x])
      dfs2(son[x],tp);
    for (int v,i = head[x]; i; i = edge[i].next)
    if ((v = edge[i].to) != fa[x] && v != son[x])
      dfs2(v,v);
}
void build(int x,int l,int r)
{
    a[x].l = l;
    a[x].r = r;
    if (l == r)
    {
       a[x].f = -1;
       return;
    }
    int mid = (l + r) / 2;
    build(x * 2,l,mid);
    build(x * 2 + 1,mid + 1,r);
}
void down(int x)
{
    a[x * 2].f = a[x].f;
    a[x * 2 + 1].f = a[x].f;
    a[x * 2].w = a[x].f * (a[x * 2].r - a[x * 2].l + 1);
    a[x * 2 + 1].w = a[x].f * (a[x * 2 + 1].r - a[x * 2 + 1].l + 1);
    a[x].f = -1;
}
void change_interval(int k)
{
    if (a[k].l >= as && a[k].r <= bs)
    {
        a[k].f = g;
        a[k].w = (a[k].r - a[k].l + 1) * g; 
        return;
    }
    int mid = (a[k].l + a[k].r) / 2;
    if (a[k].f != -1) down(k);
    if (as <= mid)
      change_interval(k * 2);
    if (bs > mid)
      change_interval(k * 2 + 1);
    a[k].w = a[k * 2].w + a[k * 2 + 1].w;
}
void ask_interval(int k)
{
    if (a[k].l >= as && a[k].r <= bs) 
    {
        anss += a[k].w;
        return;
    }
    int mid = (a[k].l + a[k].r) / 2;
    if (a[k].f != -1) down(k);
    if (as <= mid)
      ask_interval(k * 2);
    if (bs > mid)
      ask_interval(k * 2 + 1);
}
int Si(int x)
{
    int ans = 0;
    int fs = top[x];
    while (fs)
    {
        anss = 0;
        as = id[fs];
        bs = id[x];
        ask_interval(1);
        ans += id[x] - id[fs] - anss + 1;
        change_interval(1);
        x = fa[fs];
        fs = top[x];
    }
    as = id[0];
    bs = id[x];
    anss = 0;
    ask_interval(1);
    ans += id[x] - id[0] - anss + 1;
    change_interval(1);
    return ans;
}
signed main(){
    scanf("%lld",&n);
    for (int i = 1; i < n; i++)
    {
        int x;
        scanf("%lld",&x);
        ins(x,i);
        ins(i,x);
    }
    dfs1(0);
    dfs2(0,0);
    build(1,1,n);
    scanf("%lld",&m);
    for (int i = 1; i <= m; i++)
    {
        char s[10];
        int x;
        scanf("%s%lld",s,&x);
        anss = 0;
        cnt = 0;
        if (s[0] == 'i')
        {
            g = 1;
            printf("%lld\n",Si(x));
        } else
        {
            as = id[x];
            bs = id[x] + size[x] - 1;
            ask_interval(1);
            g = 0;
            change_interval(1); 
            printf("%lld\n",anss);
        }	
    }
    return 0;
}

推荐题目

[SDOI2011]染色 DP+树剖

posted @ 2018-12-21 20:40  taoyc  阅读(166)  评论(0编辑  收藏  举报