spoj913 QTREE2 Query on a treeⅡ

题目链接:http://www.spoj.com/problems/QTREE2/en/

题目大意:

N个节点的树,边的编号为1~N-1,每条边有一个权值,要求模拟两种操作:
1:DIST a b :求点a到点b之间的距离
2:KTH a b k : 求从a出发到b遇到的第k个节点的编号


题解:

树剖的话,对于操作一就类模版询问,对于操作二就开两个数组分别记录从a起走到的点及从b起走到的点(直接跳(不会超时orz) 最后直接在数组找
倍增LCA的话就找a,b的LCA看看第k个点是在a到LCA的路径上还是LCA到b的路径上,倍增找点√

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define maxn 50100

struct node
{
	int x,y,c,next;
}a[maxn],e[maxn];int len,first[maxn];
struct trnode
{
	int l,r,lc,rc,c;
}tr[maxn];int trlen,z;
int son[maxn],fa[maxn],dep[maxn];
int ys[maxn],top[maxn],tot[maxn];
int mymax(int x,int y){return (x>y)?x:y;}
void ins(int x,int y,int c)
{
	len++;a[len].c=c;
	a[len].x=x;a[len].y=y;
	a[len].next=first[x];first[x]=len;
}
void dfs1(int x)
{
	son[x]=0;tot[x]=1;
	for (int k=first[x];k!=-1;k=a[k].next)
	{
		int y=a[k].y;
		if (y!=fa[x])
		{
			dep[y]=dep[x]+1;
			fa[y]=x;
			dfs1(y);
			if (tot[son[x]]<tot[y]) son[x]=y;
			tot[x]+=tot[y];
		}
	}
}
void dfs2(int x,int tp)
{
	ys[x]=++z;top[x]=tp;
	if (son[x]!=0) dfs2(son[x],tp);
	for (int k=first[x];k!=-1;k=a[k].next)
	{
		int y=a[k].y;
		if (y!=fa[x] && y!=son[x])
		 dfs2(y,y);
	}
}
void bt(int l,int r)
{
	++trlen;int now=trlen;
	tr[now].l=l;tr[now].r=r;
	tr[now].lc=tr[now].rc=-1;
	tr[now].c=0;
	if (l<r)
	{
		int mid=(l+r)>>1;
		tr[now].lc=trlen+1;bt(l,mid);
		tr[now].rc=trlen+1;bt(mid+1,r);
	}
}
void change(int now,int x,int k)
{
	if (tr[now].l==tr[now].r) {tr[now].c=k;return;}
	int mid=(tr[now].l+tr[now].r)>>1,lc=tr[now].lc,rc=tr[now].rc;
	if (x<=mid) change(lc,x,k);
	else change(rc,x,k);
	tr[now].c=tr[lc].c+tr[rc].c;
}
int fsum(int now,int l,int r)
{
	if (tr[now].l==l && tr[now].r==r) return tr[now].c;
	int lc=tr[now].lc,rc=tr[now].rc,mid=(tr[now].l+tr[now].r)>>1;
	if (r<=mid) return fsum(lc,l,r);
	else if (l>mid) return fsum(rc,l,r);
	else return fsum(lc,l,mid)+fsum(rc,mid+1,r);
}
int query(int x,int y)
{
	int ans=0,tx=top[x],ty=top[y];
	while (tx!=ty)
	{
		if (dep[tx]>dep[ty])
		{
			int tt=tx;tx=ty;ty=tt;
			tt=x;x=y;y=tt;
		}ans+=fsum(1,ys[ty],ys[y]);
		y=fa[ty];ty=top[y];
	}
	if (x==y) return ans;
	else
	{
		if (dep[x]>dep[y]) {int tt=x;x=y;y=tt;}
		return ans+fsum(1,ys[son[x]],ys[y]);
	}
}
int l1,l2,lt1[maxn],lt2[maxn];//两个记录的数组
int stp(int x,int y,int k)//就是跳啊跳orz
{
	l1=l2=0;
	while (x!=y)
	{
		if (dep[x]>dep[y])
		{
			lt1[++l1]=x;x=fa[x];
			if (k<=l1) return lt1[k];//如果还没到LCA就找到了直接返回
		}else {lt2[++l2]=y;y=fa[y];}
	}
	k-=l1;lt2[++l2]=x;
	return lt2[l2-k+1];//过了LCA的就算一下在另一条链的位置 返回
}
int main()
{
	//freopen("a.in","r",stdin);
	//freopen("a.out","w",stdout);
	int T,n,i,x,y,c,k;char s[20];
	scanf("%d",&T);
	while(T--)
	{
		scanf("%d",&n);
		len=0;memset(first,-1,sizeof(first));
		for (i=1;i<n;i++)
		{
			scanf("%d%d%d",&x,&y,&c);
			ins(x,y,c);ins(y,x,c);
			e[i].x=x;e[i].y=y;e[i].c=c;
		}
		fa[1]=0;dep[1]=0;dfs1(1);
		z=0;dfs2(1,1);
		trlen=0;bt(1,n);
		for (i=1;i<n;i++)
		{
			if (dep[e[i].x]>dep[e[i].y])
			{
				int tt=e[i].x;e[i].x=e[i].y;e[i].y=tt;
			}
			change(1,ys[e[i].y],e[i].c);
		}
		while (1)
		{
			scanf("\n%s",s);
			if (s[1]=='O') break;
			else if (s[0]=='D')
			{
				scanf("%d%d",&x,&y);
				printf("%d\n",query(x,y));
			}else if (s[0]=='K')
			{
				scanf("%d%d%d",&x,&y,&k);
				printf("%d\n",stp(x,y,k));
			}
		}
	}
	return 0;
}


posted @ 2016-09-06 16:50  OxQ  阅读(125)  评论(0编辑  收藏  举报