【YbtOJ#532】往事之树
题目
题目链接:https://www.ybtoj.com.cn/contest/117/problem/3
\(n\leq 2\times 10^5\)。
思路
不难发现两个串 \(r(x),r(y)\) 的 LCS 就是它们 LCA 的深度。考虑枚举 LCA,然后求子树内所有字符串的 LCP 最大值。
发现题目给出的是一棵 Trie,我们可以直接离线构造广义 SAM。那么此时树上两个点 \(x,y\) 的 LCP 长度就是他们在 parent 树上的 LCA 的 \(\text{len}\)。
但是我们不能依次枚举子树内的两个点,但是我们发现 parent 树上儿子节点的 \(\text{len}\) 一定大于它父亲的 \(\text{len}\),所以我们没有必要求 \(O(n^2)\) 个点对的 LCA,只需要把他们按照 parent 树上 dfs 序相邻的计算一下就可以了。
此时我们依然需要维护一个数据结构支持维护子树内的信息,并且支持往父节点合并。考虑权值线段树,线段树一个叶子 \([i,i]\) 表示 parent 树上 dfs 序为 \(i\) 的点。如果这个点在当前子树中就为 \(1\),否则为 \(0\)。
然后权值线段树上维护区间 dfs 序相邻的点的 LCP 最大值,以及区间最左最右的点。两个区间 pushup 时可能产生的贡献只有区间临界点左右的一对点。可以 \(O(\log n)\) pushup(如果用 ST 表预处理 LCA 就可以做到 \(O(1)\) pushup,这样总复杂度只有一个 \(\log\))。
然后往上线段树合并即可。
时间复杂度 \(O(n\log^2 n)\)。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=400010,LG=20,MAXN=N*LG;
int n,ans,tot,a[N],dep[N],rt[N],last[N],head[N];
struct edge
{
int next,to;
}e[N];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
struct SAM
{
int tot,fa[N],len[N],dfn[N],rk[N],dep[N],pa[N][LG+1];
map<int,int> ch[N];
vector<int> e[N];
SAM() { tot=1; }
int ins(int last,int c)
{
int p=last,np=++tot;
len[np]=len[p]+1;
for (;!ch[p][c];p=fa[p]) ch[p][c]=np;
if (!p) fa[np]=1;
else
{
int q=ch[p][c];
if (len[q]==len[p]+1) fa[np]=q;
else
{
int nq=++tot;
fa[nq]=fa[q]; len[nq]=len[p]+1; ch[nq]=ch[q];
fa[q]=fa[np]=nq;
for (;ch[p][c]==q;p=fa[p]) ch[p][c]=nq;
}
}
return np;
}
void adde()
{
for (int i=1;i<=tot;i++)
if (fa[i]) e[fa[i]].push_back(i);
}
void dfs(int x)
{
dfn[x]=++tot; rk[tot]=x;
dep[x]=dep[fa[x]]+1; pa[x][0]=fa[x];
for (int i=1;i<=LG;i++)
pa[x][i]=pa[pa[x][i-1]][i-1];
for (int i=0;i<e[x].size();i++)
dfs(e[x][i]);
}
int lca(int x,int y)
{
x=rk[x]; y=rk[y];
if (dep[x]<dep[y]) swap(x,y);
for (int i=LG;i>=0;i--)
if (dep[pa[x][i]]>=dep[y]) x=pa[x][i];
if (x==y) return x;
for (int i=LG;i>=0;i--)
if (pa[x][i]!=pa[y][i]) x=pa[x][i],y=pa[y][i];
return pa[x][0];
}
}sam;
struct SegTree
{
int tot,lc[MAXN],rc[MAXN],res[MAXN],L[MAXN],R[MAXN];
void pushup(int x)
{
L[x]=L[lc[x]]?L[lc[x]]:L[rc[x]];
R[x]=R[rc[x]]?R[rc[x]]:R[lc[x]];
res[x]=max(res[lc[x]],res[rc[x]]);
if (R[lc[x]] && L[rc[x]])
res[x]=max(res[x],sam.len[sam.lca(R[lc[x]],L[rc[x]])]);
}
int update(int x,int l,int r,int k)
{
if (!x) x=++tot;
if (l==r)
{
L[x]=R[x]=l;
return x;
}
int mid=(l+r)>>1;
if (k<=mid) lc[x]=update(lc[x],l,mid,k);
else rc[x]=update(rc[x],mid+1,r,k);
pushup(x);
return x;
}
int merge(int x,int y)
{
if (!x || !y) return x|y;
int p=++tot;
res[p]=max(res[x],res[y]);
lc[p]=merge(lc[x],lc[y]);
rc[p]=merge(rc[x],rc[y]);
pushup(p);
return p;
}
}seg;
void dfs1(int x,int fa)
{
dep[x]=dep[fa]+1;
last[x]=sam.ins(last[fa],a[x]);
for (int i=head[x];~i;i=e[i].next)
dfs1(e[i].to,x);
}
void dfs2(int x)
{
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
dfs2(v);
rt[x]=seg.merge(rt[x],rt[v]);
}
rt[x]=seg.update(rt[x],1,2*n,sam.dfn[last[x]]);
ans=max(ans,seg.res[rt[x]]+dep[x]-1);
}
int main()
{
freopen("recollection.in","r",stdin);
freopen("recollection.out","w",stdout);
memset(head,-1,sizeof(head));
scanf("%d",&n);
for (int i=2,x;i<=n;i++)
{
scanf("%d%d",&x,&a[i]);
add(x,i);
}
dfs1(1,0);
sam.adde();
sam.tot=0; sam.dfs(1);
dfs2(1);
printf("%d",ans);
return 0;
}