Count on a tree SPOJ 10628 主席树+LCA(树链剖分实现)(两种存图方式)

Count on a tree SPOJ 10628 主席树+LCA(树链剖分实现)(两种存图方式)

题外话,这是我第40篇随笔,纪念一下。<( ̄︶ ̄)↗[GO!]

题意

是说有棵树,每个节点上都有一个值,然后让你求从一个节点到另一个节点的最短路上第k小的值是多少。

解题思路

看到这个题一想以为是树链剖分+主席树,后来写着写着发现不对,因为树链剖分我们分成了一小段一小段,这些小段不能合并起来求第k小,所以这个想法不对。奈何不会做,查了查题解,需要用LCA(最近公共祖先),然后根据主席树具有区间加减的性质,我们查询一段区间的状态可以从LCA的角度去看问题,找到LCA(x, y)然后,我们只要一个LCA节点,然后求出区间X到根节点,以及Y到根节点的关系式来推这个关系,但是千万不要去减两倍LCA的关系,因为那样就会少掉一个节点了,于是,就dfs()往下建树,就是寻找到最后的答案。

\[t[t[x].l].sum + t[t[y].l].sum - t[t[lca].l].sum - t[t[gra].l].sum \]

注意:这里我求LCA的方法是用的树链剖分的方法,求LCA的方法有很多,但是我就会这一种🙃

如果没有学过树链剖分,我这里有一些学习资料推荐,点我

如果没有学过主席树,别急,我这也有好的视频和博客推荐,点我

上面两个都是我学习过程中遇到的好的博客文章的收集,节省再次查找的时间

代码实现(图用vector存版,用链式向前星版)

//vector版,方便但是稍微慢一些
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
const int maxn=1e5+100;
struct node{
	int l, r, sum;
}t[maxn*40];
vector<int> g[maxn];
vector<int> v;
int tot, root[maxn], w[maxn];

int dep[maxn], f[maxn], size[maxn], son[maxn]; 
int top[maxn];
int n, m, nn; //nn是实际去重后的个数
int read() //快读函数,这里已经实验过了,用和不用都行,用了会快一些。
{
    int f=1,x=0;
    char ss=getchar();
    while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
    while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
    return f*x;
}
int getid(int x) //离散化之后求坐标
{
	return lower_bound(v.begin() , v.end() , x)- v.begin() +1;
}
void dfs1(int u, int fa, int depth)
{
	size[u]=1;
	dep[u]=depth;
	f[u]=fa;
	int len=g[u].size();
	for(int i=0; i<len; i++)
	{
		int v=g[u][i];
		if(v==fa) 
			continue;
		dfs1(v, u, depth+1);
		size[u]+=size[v];
		if(size[v] > size[son[u]])
			son[u]=v;
	}
}
void dfs2(int u, int tp)
{
	top[u]=tp;
	if(!son[u])
		return ;
	dfs2(son[u], tp);
	int len=g[u].size();
	for(int i=0; i<len; i++)
	{
		int v=g[u][i];
		if(v==son[u] || v==f[u])
			continue;
		dfs2(v, v);
	}
}
int lca(int x,int y)
{
	int fx=top[x], fy=top[y];
	while(fx!=fy)
	{
		if(dep[fx] < dep[fy]) 
		{
			swap(x, y);
			swap(fx, fy);
		}
		x=f[fx]; //这里右边是fx,千万别写错了,我就是这犯了错,wa了几十下。。。
		fx=top[x];
	}
	return dep[x] < dep[y] ? x : y ;
}
void update(int l, int r, int pre, int &now, int pos)
{
	t[++tot]=t[pre];
	t[tot].sum++;
	now=tot;
	if(l==r) return ;
	int mid=(l+r)>>1;
	if(pos<=mid)
		update(l, mid, t[pre].l, t[now].l, pos);
	else 
		update(mid+1, r, t[pre].r, t[now].r, pos);
}
int query(int l, int r, int x, int y, int lca, int gra, int k)
{
	if(l==r)
		return l;
	int mid=(l+r)>>1;
	int sum=t[t[x].l].sum + t[t[y].l].sum - t[t[lca].l].sum - t[t[gra].l].sum ;
	if(k<=sum)
		return query(l, mid, t[x].l, t[y].l, t[lca].l, t[gra].l, k);
	else 
		return query(mid+1, r, t[x].r, t[y].r, t[lca].r, t[gra].r, k-sum); 
}
void dfs(int u)
{
	int pos=getid(w[u]);
	update(1, nn, root[f[u]], root[u], pos);
	int len=g[u].size();
	for(int i=0; i<len; i++)
	{
		int v=g[u][i];
		if(v==f[u])
			continue;
		dfs(v);
	}
}
int main()
{
	scanf("%d%d", &n, &m); 
	for(int i=1; i<=n; i++)
	{
		w[i]=read();
		v.push_back(w[i]);
	}
	sort(v.begin() , v.end() );
	v.erase( unique( v.begin() , v.end() ), v.end());
	nn=v.size();
	int x, y;
	for(int i=1; i<n; i++)
	{
		scanf("%d%d", &x, &y);
		g[x].push_back(y);
		g[y].push_back(x);
	}
	dfs1(1, 0, 1);
	dfs2(1, 1);
	dfs(1);
	int k, la;
	for(int i=1; i<=m; i++)
	{
		scanf("%d%d%d", &x, &y, &k);
		la=lca(x, y);
		printf("%d\n", v[ query(1, nn, root[x], root[y], root[la], root[ f[la] ], k) -1 ] );
	}
	return 0;
}
// 链式向前星版存图
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<iostream>
using namespace std;
const int maxn=1e5+10000;
struct node{
	int l, r, sum;
}t[maxn*40];
int tot; //tot是主席树点的个数 

struct edge{
	int to, next;
}e[maxn<<1];
int head[maxn], cnt; //cnt是边的个数 

int root[maxn], w[maxn], id[maxn];

int dep[maxn], f[maxn], size[maxn], son[maxn];
int top[maxn];
int n, m, nn;
inline int read()
{
    int f=1,x=0;
    char ss=getchar();
    while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
    while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
    return f*x;
}
inline void add(int u, int v)
{
	e[++cnt].next=head[u];
	e[cnt].to=v;
	head[u]=cnt;
}

int getid(int x)
{
	return (lower_bound(&id[1] , &id[nn+1] , x)-id);
}

void update(int l, int r, int pre, int &now, int pos)
{
	t[++tot]=t[pre];
	t[tot].sum++;
	now=tot;
	if(l==r) return ;
	int mid=(l+r)>>1;
	if(pos<=mid)
		update(l, mid, t[pre].l, t[now].l, pos);
	else
		update(mid+1, r, t[pre].r, t[now].r, pos);
}

int query(int l, int r, int x, int y, int lca, int gra, int k)
{
	if(l==r)
		return l;
	int mid=(l+r)>>1;
	int sum=t[t[x].l].sum + t[t[y].l].sum - t[t[lca].l].sum - t[t[gra].l].sum ;
	if(k<=sum)
		return query(l, mid, t[x].l, t[y].l, t[lca].l, t[gra].l, k);
	else
		return query(mid+1, r, t[x].r, t[y].r, t[lca].r, t[gra].r, k-sum);
}
void dfs1(int u, int fa, int depth)
{
	size[u]=1;
	dep[u]=depth;
	f[u]=fa;
	for(int i=head[u]; i; i=e[i].next)
	{
		int v=e[i].to;
		if(v==fa)
			continue;
		dfs1(v, u, depth+1);
		size[u]+=size[v];
		if(size[v] > size[son[u]])
			son[u]=v;
	}
}
void dfs2(int u, int tp)
{
	top[u]=tp;
	if(!son[u])
		return ;
	dfs2(son[u], tp);
	for(int i=head[u]; i; i=e[i].next)
	{
		int v=e[i].to;
		if(v==son[u] || v==f[u])
			continue;
		dfs2(v, v);
	}
}

int lca(int x,int y)
{
	int fx=top[x], fy=top[y];
	while(fx!=fy)
	{
		if(dep[fx] < dep[fy])
		{
			swap(x, y);
			swap(fx, fy);
		}
		x=f[fx];
		fx=top[x];
	}
	return dep[x] < dep[y] ? x : y ;
}
void dfs(int u)
{
	int pos=getid(w[u]);
	update(1, nn, root[f[u]], root[u], pos);
	for(int i=head[u]; i; i=e[i].next)
	{
		int v=e[i].to;
		if(v==f[u])
			continue;
		dfs(v);
	}
}
int main()
{
	scanf("%d%d", &n, &m);
	for(int i=1; i<=n; i++)
	{
		scanf("%d", &w[i]);
		id[i]=w[i];
	}
	sort(&id[1], &id[n+1] );
	nn = (unique( &id[1], &id[n+1])-id-1) ;
	int x, y;
	for(int i=1; i<n; i++)
	{
		scanf("%d%d", &x, &y);
		add(x, y);
		add(y, x);
	}
	dfs1(1, 0, 1);
	dfs2(1, 1);
	dfs(1);
	int k, la;
	for(int i=1; i<=m; i++)
	{
		scanf("%d%d%d", &x, &y, &k);
		la=lca(x, y);
		printf("%d\n", id[ query(1, nn, root[x], root[y], root[la], root[ f[la] ], k) ] );
	}
	return 0;
}

posted @ 2019-08-29 22:46  ALKING1001  阅读(176)  评论(0编辑  收藏  举报