【CF504E】Misha and LCP on Tree

题目

题目链接:https://codeforces.ml/problemset/problem/504/E
给定一棵 \(n\) 个节点的树,每个节点有一个小写字母。
\(m\) 组询问,每组询问为树上 \(a \to b\)\(c \to d\) 组成的字符串的最长公共前缀。
\(n \le 3 \times 10^5\)\(m \le 10^6\)

思路

直接在树上显然是没办法做的。我们考虑将树上问题转化为序列上的问题。
为了把树上的链扔到序列上,我们不难想到重链剖分。这样每一条链都被我们转化为了序列上 \(O(\log n)\) 个区间。
我们把字符串重新排序,让树上节点 \(i\) 所对应的字符,在新字符串重排在树剖后 \(i\) 的编号的位置上。然后求出后缀数组,可以 ST 表 \(O(1)\) 求出两个后缀的 LCP。
对于每一个询问,我们分别求出两条链在新的字符串中对应的区间。由于存在从下往上的路径,和我们的编号恰好相反,所以我们需要先把字符串复制一份并翻转放在最后面再跑 SA。
然后我们只需要对 \(O(\log n)\) 个区间求 LCP。当某一位已经不匹配时就退出即可。
时间复杂度 \(O(n\log n)\)

代码

#include <bits/stdc++.h>
using namespace std;

const int N=600010,LG=20; 
int q11[N],q12[N],q21[N],q22[N];
int head[N],top[N],son[N],siz[N],id[N],rk[N],dep[N],fa[N];
int c[N],x[N],y[N],sa[N],lg[N],height[N],st[N][LG+1];
int n,m,tot,len1,len2;
char s[N],t[N];

struct edge
{
	int next,to;
}e[N];

void add(int from,int to)
{
	e[++tot]=(edge){head[from],to};
	head[from]=tot;
}

void dfs1(int x,int f)
{
	dep[x]=dep[f]+1; fa[x]=f; siz[x]=1;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=f)
		{
			dfs1(v,x);
			siz[x]+=siz[v];
			if (siz[v]>siz[son[x]]) son[x]=v;
		}
	}
}

void dfs2(int x,int tp)
{
	top[x]=tp; id[x]=++tot; rk[tot]=x;
	if (son[x]) dfs2(son[x],tp);
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa[x] && v!=son[x]) dfs2(v,v);
	}
}

void SA(int n,int m)
{
	for (int i=1;i<=n;i++) x[i]=s[i],c[x[i]]++;
	for (int i=1;i<=m;i++) c[i]+=c[i-1];
	for (int i=n;i>=1;i--) sa[c[x[i]]--]=i;
	for (int k=1;k<=n;k<<=1)
	{
		int num=0;
		for (int i=n-k+1;i<=n;i++) y[++num]=i;
		for (int i=1;i<=n;i++) if (sa[i]>k) y[++num]=sa[i]-k;
		for (int i=1;i<=m;i++) c[i]=0;
		for (int i=1;i<=n;i++) c[x[i]]++;
		for (int i=1;i<=m;i++) c[i]+=c[i-1];
		for (int i=n;i>=1;i--) sa[c[x[y[i]]]--]=y[i],y[i]=0;
		swap(x,y);
		num=x[sa[1]]=1;
		for (int i=2;i<=n;i++)
			x[sa[i]]=(y[sa[i]]==y[sa[i-1]] && y[sa[i]+k]==y[sa[i-1]+k])?num:++num;
		m=num;
		if (n==m) return;
	}
}

void geth(int n)
{
	for (int i=1;i<=n;i++) rk[sa[i]]=i;
	for (int i=1,k=0;i<=n;i++)
	{
		if (k) k--;
		int j=sa[rk[i]-1];
		while (s[i+k]==s[j+k]) k++;
		height[rk[i]]=k;
	}
}

void getst(int n)
{
	for (int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
	for (int i=1;i<=n;i++) st[i][0]=height[i];
	for (int i=n;i>=1;i--)
		for (int j=1;i+(1<<j)-1<=n;j++)
			st[i][j]=min(st[i][j-1],st[i+(1<<j-1)][j-1]);
}

int lca(int x,int y)
{
	while (top[x]!=top[y])
	{
		if (dep[top[x]]<dep[top[y]]) swap(x,y);
		x=fa[top[x]];
	}
	if (dep[x]<dep[y]) swap(x,y);
	return y;
}

int findson(int y,int x)
{
	int last=-1;
	for (;top[x]!=top[y];x=fa[top[x]])
		last=top[x];
	if (x==y) return last;
	return son[y];
}

void find(int *q1,int *q2,int &cnt,int x,int y)
{
	if (y==-1) return;
	for (;top[x]!=top[y];x=fa[top[x]])
		q1[++cnt]=id[top[x]],q2[cnt]=id[x];
	q1[++cnt]=id[y]; q2[cnt]=id[x];
}

int lcp(int i,int j)
{
	if (i==j) return 1e9;
	if (i>j) swap(i,j);
	int k=lg[j-i];
	return min(st[i+1][k],st[j-(1<<k)+1][k]);
}

int solve()
{
	int ans=0;
	for (int i=1,j=1;i<=len1 && j<=len2;)
	{
		int len=lcp(rk[q11[i]],rk[q21[j]]);
		if (q12[i]-q11[i]<q22[j]-q21[j])
		{
			if (q12[i]-q11[i]+1>len) return ans+len;
			ans+=q12[i]-q11[i]+1;
			q21[j]+=q12[i]-q11[i]+1; i++;
		}
		else if (q12[i]-q11[i]>q22[j]-q21[j])
		{
			if (q22[j]-q21[j]+1>len) return ans+len;
			ans+=q22[j]-q21[j]+1;
			q11[i]+=q22[j]-q21[j]+1; j++;
		}
		else
		{
			if (q12[i]-q11[i]+1>len) return ans+len;
			ans+=q12[i]-q11[i]+1;
			i++; j++;
		}
	}
	return ans;
}

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%s",&n,t+1);
	for (int i=1,x,y;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	tot=0;
	dfs1(1,0); dfs2(1,1);
	for (int i=1;i<=n;i++)
		s[i]=s[2*n-i+1]=t[rk[i]];
	SA(2*n,'z'); geth(2*n); getst(2*n);
	scanf("%d",&m);
	while (m--)
	{
		int u1,v1,u2,v2,p1,p2;
		scanf("%d%d%d%d",&u1,&v1,&u2,&v2);
		p1=lca(u1,v1); p2=lca(u2,v2);
		len1=0;
		find(q12,q11,len1,u1,findson(p1,u1));
		for (int i=len1;i>=1;i--)
		{
			q11[i]=2*n-q11[i]+1;
			q12[i]=2*n-q12[i]+1;
		}
		int tmp=len1;
		find(q11,q12,len1,v1,p1);
		reverse(q11+tmp+1,q11+len1+1);
		reverse(q12+tmp+1,q12+len1+1);
		len2=0;
		find(q22,q21,len2,u2,findson(p2,u2));
		for (int i=len2;i>=1;i--)
		{
			q21[i]=2*n-q21[i]+1;
			q22[i]=2*n-q22[i]+1;
		}
		tmp=len2;
		find(q21,q22,len2,v2,p2);
		reverse(q21+tmp+1,q21+len2+1);
		reverse(q22+tmp+1,q22+len2+1);
		printf("%d\n",solve());
	}
	return 0;
}
posted @ 2021-01-14 16:44  stoorz  阅读(118)  评论(0编辑  收藏  举报