【CF504E】Misha and LCP on Tree
题目
题目链接:https://codeforces.ml/problemset/problem/504/E
给定一棵 \(n\) 个节点的树,每个节点有一个小写字母。
有 \(m\) 组询问,每组询问为树上 \(a \to b\) 和 \(c \to d\) 组成的字符串的最长公共前缀。
\(n \le 3 \times 10^5\),\(m \le 10^6\)。
思路
直接在树上显然是没办法做的。我们考虑将树上问题转化为序列上的问题。
为了把树上的链扔到序列上,我们不难想到重链剖分。这样每一条链都被我们转化为了序列上 \(O(\log n)\) 个区间。
我们把字符串重新排序,让树上节点 \(i\) 所对应的字符,在新字符串重排在树剖后 \(i\) 的编号的位置上。然后求出后缀数组,可以 ST 表 \(O(1)\) 求出两个后缀的 LCP。
对于每一个询问,我们分别求出两条链在新的字符串中对应的区间。由于存在从下往上的路径,和我们的编号恰好相反,所以我们需要先把字符串复制一份并翻转放在最后面再跑 SA。
然后我们只需要对 \(O(\log n)\) 个区间求 LCP。当某一位已经不匹配时就退出即可。
时间复杂度 \(O(n\log n)\)。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=600010,LG=20;
int q11[N],q12[N],q21[N],q22[N];
int head[N],top[N],son[N],siz[N],id[N],rk[N],dep[N],fa[N];
int c[N],x[N],y[N],sa[N],lg[N],height[N],st[N][LG+1];
int n,m,tot,len1,len2;
char s[N],t[N];
struct edge
{
int next,to;
}e[N];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
void dfs1(int x,int f)
{
dep[x]=dep[f]+1; fa[x]=f; siz[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=f)
{
dfs1(v,x);
siz[x]+=siz[v];
if (siz[v]>siz[son[x]]) son[x]=v;
}
}
}
void dfs2(int x,int tp)
{
top[x]=tp; id[x]=++tot; rk[tot]=x;
if (son[x]) dfs2(son[x],tp);
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa[x] && v!=son[x]) dfs2(v,v);
}
}
void SA(int n,int m)
{
for (int i=1;i<=n;i++) x[i]=s[i],c[x[i]]++;
for (int i=1;i<=m;i++) c[i]+=c[i-1];
for (int i=n;i>=1;i--) sa[c[x[i]]--]=i;
for (int k=1;k<=n;k<<=1)
{
int num=0;
for (int i=n-k+1;i<=n;i++) y[++num]=i;
for (int i=1;i<=n;i++) if (sa[i]>k) y[++num]=sa[i]-k;
for (int i=1;i<=m;i++) c[i]=0;
for (int i=1;i<=n;i++) c[x[i]]++;
for (int i=1;i<=m;i++) c[i]+=c[i-1];
for (int i=n;i>=1;i--) sa[c[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y);
num=x[sa[1]]=1;
for (int i=2;i<=n;i++)
x[sa[i]]=(y[sa[i]]==y[sa[i-1]] && y[sa[i]+k]==y[sa[i-1]+k])?num:++num;
m=num;
if (n==m) return;
}
}
void geth(int n)
{
for (int i=1;i<=n;i++) rk[sa[i]]=i;
for (int i=1,k=0;i<=n;i++)
{
if (k) k--;
int j=sa[rk[i]-1];
while (s[i+k]==s[j+k]) k++;
height[rk[i]]=k;
}
}
void getst(int n)
{
for (int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
for (int i=1;i<=n;i++) st[i][0]=height[i];
for (int i=n;i>=1;i--)
for (int j=1;i+(1<<j)-1<=n;j++)
st[i][j]=min(st[i][j-1],st[i+(1<<j-1)][j-1]);
}
int lca(int x,int y)
{
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if (dep[x]<dep[y]) swap(x,y);
return y;
}
int findson(int y,int x)
{
int last=-1;
for (;top[x]!=top[y];x=fa[top[x]])
last=top[x];
if (x==y) return last;
return son[y];
}
void find(int *q1,int *q2,int &cnt,int x,int y)
{
if (y==-1) return;
for (;top[x]!=top[y];x=fa[top[x]])
q1[++cnt]=id[top[x]],q2[cnt]=id[x];
q1[++cnt]=id[y]; q2[cnt]=id[x];
}
int lcp(int i,int j)
{
if (i==j) return 1e9;
if (i>j) swap(i,j);
int k=lg[j-i];
return min(st[i+1][k],st[j-(1<<k)+1][k]);
}
int solve()
{
int ans=0;
for (int i=1,j=1;i<=len1 && j<=len2;)
{
int len=lcp(rk[q11[i]],rk[q21[j]]);
if (q12[i]-q11[i]<q22[j]-q21[j])
{
if (q12[i]-q11[i]+1>len) return ans+len;
ans+=q12[i]-q11[i]+1;
q21[j]+=q12[i]-q11[i]+1; i++;
}
else if (q12[i]-q11[i]>q22[j]-q21[j])
{
if (q22[j]-q21[j]+1>len) return ans+len;
ans+=q22[j]-q21[j]+1;
q11[i]+=q22[j]-q21[j]+1; j++;
}
else
{
if (q12[i]-q11[i]+1>len) return ans+len;
ans+=q12[i]-q11[i]+1;
i++; j++;
}
}
return ans;
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%s",&n,t+1);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
tot=0;
dfs1(1,0); dfs2(1,1);
for (int i=1;i<=n;i++)
s[i]=s[2*n-i+1]=t[rk[i]];
SA(2*n,'z'); geth(2*n); getst(2*n);
scanf("%d",&m);
while (m--)
{
int u1,v1,u2,v2,p1,p2;
scanf("%d%d%d%d",&u1,&v1,&u2,&v2);
p1=lca(u1,v1); p2=lca(u2,v2);
len1=0;
find(q12,q11,len1,u1,findson(p1,u1));
for (int i=len1;i>=1;i--)
{
q11[i]=2*n-q11[i]+1;
q12[i]=2*n-q12[i]+1;
}
int tmp=len1;
find(q11,q12,len1,v1,p1);
reverse(q11+tmp+1,q11+len1+1);
reverse(q12+tmp+1,q12+len1+1);
len2=0;
find(q22,q21,len2,u2,findson(p2,u2));
for (int i=len2;i>=1;i--)
{
q21[i]=2*n-q21[i]+1;
q22[i]=2*n-q22[i]+1;
}
tmp=len2;
find(q21,q22,len2,v2,p2);
reverse(q21+tmp+1,q21+len2+1);
reverse(q22+tmp+1,q22+len2+1);
printf("%d\n",solve());
}
return 0;
}