树链剖分学习笔记(未完)

思想

树链剖分用于将树分割成若干条链的形式,以维护树上路径的信息。
具体来说,将整棵树剖分为若干条链,使它组合成线性结构,然后用其他的数据结构维护信息。

重链剖分

原理

首先先定义一些概念

概念 定义
重儿子 每个点的子树中,子树的节点数和最大的子节点
轻儿子 除重儿子外的其他子节点
重边 每个节点与其重儿子间的边
轻边 每个节点与其轻儿子间的边
重链 重边连成的链 (一个点也可以看作是重链)
轻链 轻边连成的链

因此我们可以将一颗树上的所有节点划分到若干条重链上,如图(图源:OiWiki)
image
如右图所示,整棵树被分为一条一条的重链,我们可以在dfs预处理的时候优先处理重儿子,这样可以保证每一条重链上的点的dfs序是连续的,这样就可以把树上的问题转换成区间问题,利用其他处理区间的数据结构来处理树上问题

实现

预处理

int dep[N],fa[N],sz[N],son[N],top[N];
int id[N],nw[N],timestamp;
void dfs1(int u,int father)
{
    dep[u] = dep[father] + 1,sz[u] = 1,fa[u] = father;
    for(int i = h[u];~ i;i = ne[i])
    {
        int j = e[i];
        if(j == fa[u]) continue;
        dfs1(j,u);
        sz[u] += sz[j];
        if(sz[son[u]] < sz[j]) son[u] = j;
    }

}

void dfs2(int u,int t)
{
    id[u] = ++ timestamp;
    nw[id[u]] = w[u];
    top[u] = t;
    if(!son[u]) return ;
    dfs2(son[u],t);
    for(int i = h[u];~ i;i = ne[i])
    {
        int j = e[i];
        if(j == fa[u] || j == son[u]) continue;
        dfs2(j,j);
    }
}

操作链

主要思想:每次操作两点中重链顶点深度较深的一个,然后给他跳到上面的链上去,类似于最近公共祖先暴力跳的步骤,直到两点在一个链上,最后处理两点间的节点

void modify_path(int u,int v,int k)
{
    while(top[u] != top[v])
    {
        if(dep[top[u]] < dep[top[v]]) swap(u,v);
        modify(1,id[top[u]],id[u],k);
        u = fa[top[u]];
    }

    if(dep[u] < dep[v]) swap(u,v);
    modify(1,id[v],id[u],k);
}

LL query_path(int u,int v)
{
    LL res = 0;
    while(top[u] != top[v])
    {
        if(dep[top[u]] < dep[top[v]]) swap(u,v);
        res += query(1,id[top[u]],id[u]);
        u = fa[top[u]];
    }
    if(dep[u] < dep[v]) swap(u,v);
    res += query(1,id[v],id[u]);
    return res;

}

例题

Beard Graph

题意:
给定一棵 n 个节点的树,初始所有边都是黑边。
有 m 个操作:
1 u:把第 u 条边改成黑边。
2 u:把第 u 条边改成白边。
3 u v:若 u 号节点和 v 号节点间存在白边,输出 -1,否则输出 u 号节点和 v 号节点间的黑边数。
分析 :这里涉及一个小tips,题目给的是边,可我们是对点进行剖分,这里可以将每条边的边权设给它连接的两个点之间深度较深的点
将树进行树链剖分,每条边的边权初始为0代表黑边,拿线段树维护即可
对于操作1、2
线段树单点修改即可
对于操作3
进行跳边查询,看有没有白边
这里注意最后处理同一条链的一段时,深度较浅的那个点不考虑,因为我们把每条边赋给了深度较深的点,这个点所代表的表不在我们要求的范围内
ac代码


int n,m,k,t;
int h[N],e[N << 1],ne[N << 1],idx;
int sz[N],fa[N],dep[N],son[N];
int id[N],top[N],nw[N],w[N],x[N],timestamp;
struct node
{
	int l,r;
	int sum ;
}tr[N << 2];

void pushup(int u)
{
	tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void build(int u,int l,int r)
{
	if(l == r)
	{
		tr[u] = {l,r,nw[r]};
		return ;
	}
	tr[u] = {l,r};
	int mid = l + r >> 1;
	build(u << 1,l,mid), build(u << 1 | 1,mid + 1,r);
}

void modify(int u,int x,int v)
{
	if(tr[u].l == tr[u].r)
	{
		tr[u].sum = v;
		return ;
	}
	int mid = tr[u].l + tr[u].r >> 1;
	if(x <= mid) modify(u << 1,x,v);
	else modify(u << 1 | 1,x,v);
	pushup(u);
}

int query(int u,int l,int r)
{
	if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
	int mid = tr[u].l + tr[u].r >> 1;
	int res = 0;
	if(l <= mid) res += query(u << 1,l,r);
	if(r > mid) res += query(u << 1 | 1,l,r);
	return res;
}

void add(int a,int b,int c)
{
	e[idx] = b,x[idx] = c,ne[idx] = h[a],h[a] = idx ++;
}

void dfs1(int u,int father)
{
	sz[u] = 1,fa[u] = father,dep[u] = dep[fa[u]] + 1;
	for(int i = h[u];~ i;i = ne[i])
	{
		int j = e[i];
		if(j == fa[u]) continue;
		w[j] = x[i];
		dfs1(j,u);
		sz[u] += sz[j];
		if(sz[son[u]] < sz[j]) son[u] = j;
	}

}

void dfs2(int u,int t)
{
	id[u] = ++ timestamp;
	nw[timestamp] = w[u];
	top[u] = t;
	if(!son[u]) return ;
	dfs2(son[u],t);
	for(int i = h[u];~ i;i = ne[i])
	{
		int j = e[i];
		if(j == fa[u] || j == son[u]) continue;
		dfs2(j,j);
	}
}
int ans;
int query_path(int u,int v)
{
	int res = 0;
	while(top[u] != top[v])
	{
		if(dep[top[u]] < dep[top[v]]) swap(u,v);
		res += query(1,id[top[u]],id[u]);
		ans += id[u] - id[top[u]] + 1;
		u = fa[top[u]];
	}
	if(dep[u] < dep[v]) swap(u,v); 
	res += query(1,id[son[v]],id[u]);
	ans += id[u] - id[v] + 1;
	return res;
}

int main()
{
	ios;
	cin >> n;	
	for(int i = 1;i <= n;i ++) h[i] = -1;
	vector<PII> edges(n);
	for(int i = 1;i < n;i ++)
	{
		int a,b;
		cin >> a >> b;
		edges[i] = {a,b};
		add(a,b,0), add(b,a,0);
	}

	dfs1(1,0);
	dfs2(1,1);
	build(1,1,n);
	cin >> m;
	while(m --)
	{
		int op,u,v;
		cin >> op ;
		if(op == 1)
		{
			cin >> u ;
			if(dep[edges[u].x] > dep[edges[u].y]) modify(1,id[edges[u].x],0);
			else modify(1,id[edges[u].y],0);
		}
		else if(op == 2)
		{
			cin >> u ;
			if(dep[edges[u].x] > dep[edges[u].y]) modify(1,id[edges[u].x],1);
			else modify(1,id[edges[u].y],1);
		}
		else
		{
			cin >> u >> v;
			if(u == v) 
			{
				cout << 0 << endl;
				continue;
			}
			ans = 0;
			int t = query_path(u,v);
			if(t) cout << -1 << endl;
			else
			{
				cout << max(ans - 1,0) << endl;
			}
		}
	}
	return 0;
}

Water Tree

题意:
给出一棵以 1 为根节点的 n 个节点的有根树。每个点有一个权值,初始为 0。
m 次操作。操作有 3 种:
将点 u 和其子树上的所有节点的权值改为 1。
将点 u 到 1 的路径上的所有节点的权值改为 0。
询问点 u 的权值。

分析:
对于第一种操作,因为树中某一个节点u的子树其dfs序都在[id[u], id[u] + sz[u] - 1]内,所以直接线段树需改即可
对于第二种操作,进行跳链操作修改
对于第三种操作,线段树单点查询
代码


int n,m,k,t;
int h[N],e[N << 1],ne[N << 1],idx;
int sz[N],fa[N],dep[N],son[N];
int id[N],top[N],timestamp;
struct node
{
	int l,r;
	int sum,val;
}tr[N << 2];

void pushup(int u)
{
	tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void pushdown(node & u,LL v)
{
	u.sum = (u.r - u.l + 1) * v;
	u.val = v;
}

void pushdown(int u)
{
	if(tr[u].val != -1) 
	{
		pushdown(tr[u << 1],tr[u].val);
		pushdown(tr[u << 1 | 1],tr[u].val);
		tr[u].val = -1;
	}
}

void build(int u,int l,int r)
{
	tr[u] = {l,r,0,-1};
	if(l == r) return ;
	int mid =  l + r >> 1;
	build(u << 1,l,mid), build(u << 1 | 1,mid + 1,r);
}

void modify(int u,int l,int r,int v)
{
	if(tr[u].l >= l && tr[u].r <= r) 
	{
		pushdown(tr[u],v);
		return;
	}

	pushdown(u);
	int mid = tr[u].l + tr[u].r >> 1;
	if(l <= mid) modify(u << 1,l,r,v);
	if(r > mid) modify(u << 1 | 1,l,r,v);
	pushup(u);
}

LL query(int u,int x)
{
	if(tr[u].l == x && tr[u].r == x) return tr[u].sum;
	pushdown(u);
	int mid = tr[u].l + tr[u].r >> 1;
	if(x <= mid) return query(u << 1,x);
	else return query(u << 1 | 1,x);
}

void add(int a,int b)
{
	e[idx] = b,ne[idx] = h[a],h[a] = idx ++;
}

void dfs1(int u,int father)
{
	sz[u] = 1,fa[u] = father,dep[u] = dep[fa[u]] + 1;
	for(int i = h[u];~ i;i = ne[i])
	{
		int j = e[i];
		if(j == fa[u]) continue;
		dfs1(j,u);
		sz[u] += sz[j];
		if(sz[j] > sz[son[u]]) son[u] = j;
	}
}

void dfs2(int u,int t)
{
	id[u] = ++ timestamp;
	top[u] = t;
	if(!son[u]) return;
	dfs2(son[u],t);

	for(int i = h[u];~ i;i = ne[i])
	{
		int j = e[i];
		if(j == fa[u] || j == son[u]) continue;
		dfs2(j,j);
	}
}

void modify_tree(int u)
{
	modify(1,id[u],id[u] + sz[u] - 1,1);
}

void modify_path(int u)
{
	while(top[u] != 1)
	{
		modify(1,id[top[u]],id[u],0);
		u = fa[top[u]];
	}
	modify(1,id[1],id[u],0);
}

int main()
{
	ios;
	cin >> n;
	for(int i = 1;i <= n;i ++) h[i] = -1;
	for(int i = 1;i < n;i ++)
	{
		int a,b;
		cin >> a >> b;
		add(a,b), add(b,a);
	}

	dfs1(1,0);
	dfs2(1,1);

	build(1,1,n);
	cin >> m;
	while(m --)
	{
		int op,u;
		cin >> op >> u;
		if(op == 1) modify_tree(u);
		else if(op == 2) modify_path(u);
		else cout << query(1,id[u]) << endl;
	}

	return 0;
}

Duff in the Army

题意:
一个国家由一颗树的形式展示,每一个节点都为一个城市,有m个人住在这个国家中,每个人都有一个编号,现在问从城市u到城市v的前a小的编号(1 <= a <= 10)

分析:
这道题难点不在树剖上,关键在于每个节点要存很多人的编号,关键在于a最多是10,所以我们在每个节点最多存10个就好,在pushup时写一个两路归并就好

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <queue>
#include <map>
#include <vector>
#include <stack>
#include <set>
#include <sstream>
#include <fstream> 
#include <cmath>
#include <iomanip>
#include <bitset>
#include <unordered_map>
#include <unordered_set>
#include <random>
//#pragma GCC optimize(3,"Ofast","inline")
#define x first
#define y second
#define ios ios::sync_with_stdio(false),cin.tie(0);
#define endl '\n'
#define pb push_back
#define all(x) x.begin(),x.end()
#define all1(x) x.begin()+1,x.end()
 
using namespace std;
 
typedef unsigned long long uLL;
typedef long long LL;
typedef pair<int,int> PII;
 
const int N = 200010,M = 10010,INF = 0x3f3f3f3f,mod = 998244353;
const double DNF = 0x7f7f7f7f7f7f7f7f,pi = acos(-1.0),eps = 1e-6;
const long long LNF = 0x3f3f3f3f3f3f3f3f;
 
int n,m,k,t;
int h[N],e[N << 1],ne[N << 1],idx;
int sz[N],fa[N],dep[N],son[N];
int id[N],top[N],nw[N],w[N],x[N << 1],timestamp;
vector<int> c[N];
struct node
{
	int l,r;
	int a[11],tt;	
}tr[N << 2];

void pushup(int u)
{
	int i = 1,j = 1,& k = tr[u].tt;
	k = 0;
	while(i <= tr[u << 1].tt && j <= tr[u << 1 | 1].tt)
	{
		if(tr[u << 1].a[i] < tr[u << 1 | 1].a[j]) tr[u].a[++ k] = tr[u << 1].a[i ++];
		else tr[u].a[++ k] = tr[u << 1 | 1].a[j ++];
		if(k == 10) return ;
	}
	while(i <= tr[u << 1].tt)
	{
		tr[u].a[++ k] = tr[u << 1].a[i ++];
		if(k == 10) return ;
	}
	while(j <= tr[u << 1 | 1].tt)
	{
		tr[u].a[++ k] = tr[u << 1 | 1].a[j ++];
		if(k == 10) return ;
	}

}

void build(int u,int l,int r)
{
	if(l == r)
	{
		tr[u] = {l,r};
		int & k =  tr[u].tt,i = 0;
		k = 0;
		while(k < 10 && i < c[w[r]].size()) tr[u].a[++ k] = c[w[r]][i ++];
		return ;
	}
	
	tr[u] = {l,r};
	int mid = l +r >> 1;
	build(u << 1,l,mid), build(u << 1 | 1,mid + 1,r);
	pushup(u);
}
int a[N],tt;
void query(int u,int l,int r)
{
	if(tr[u].l >= l && tr[u].r <= r) 
	{
		for(int i = 1;i <= tr[u].tt;i ++) a[++ tt] = tr[u].a[i];
		return ;
	}
	int mid = tr[u].l + tr[u].r >> 1;
	if(l <= mid) query(u << 1,l,r);
	if(r > mid) query(u << 1 | 1,l,r);
}

void add(int a,int b)
{
	e[idx] = b,ne[idx] = h[a],h[a] = idx ++;
}

void dfs1(int u,int father)
{
	sz[u] = 1,fa[u] = father,dep[u] = dep[fa[u]] + 1;
	for(int i = h[u];~ i;i = ne[i])
	{
		int j = e[i];
		if(j == fa[u]) continue;
		dfs1(j,u);;
		sz[u] += sz[j];
		if(sz[j] > sz[son[u]]) son[u] = j;
	}
}

void dfs2(int u,int t)
{
	id[u] = ++ timestamp;
	w[timestamp] = u;
	top[u] = t;
	if(!son[u]) return;
	dfs2(son[u],t);
	for(int i = h[u];~ i;i = ne[i])
	{
		int j = e[i];
		if(j == fa[u] || j == son[u]) continue;
		dfs2(j,j);
	}
}

void query_path(int u,int v)
{
	while(top[u] != top[v])
	{
		if(dep[top[u]] < dep[top[v]]) swap(u,v);
		query(1,id[top[u]],id[u]);
		u = fa[top[u]];
	}
	if(dep[u] < dep[v]) swap(u,v);
	query(1,id[v],id[u]);
}

int main()
{
	ios;
	cin >> n >> m >> t;
	for(int i = 1;i <= n;i ++) h[i] = -1;
	for(int i = 1;i < n;i ++)
	{
		int a,b;
		cin >> a >> b;
		add(a,b), add(b,a);
	}
	for(int i = 1;i <= m;i ++)
	{
		int x;
		cin >> x;
		c[x].pb(i);
	}

	dfs1(1,0);
	dfs2(1,1);
	build(1,1,n);

	while(t --)
	{
		int u,v,k;
		cin >> u >> v >> k;
		tt = 0;
		query_path(u,v);
		sort(a + 1,a + 1 + tt);
		cout << min(k,tt) << ' ';
		for(int i = 1;i <= min(k,tt);i ++) cout << a[i] << ' ';
		cout << endl;
	}
	return 0;
}
posted @ 2022-10-03 17:45  notyour_young  阅读(16)  评论(0编辑  收藏  举报