Luogu4218 [CTSC2010]珠宝商
Luogu4218 [CTSC2010]珠宝商
\(SAM\)+点分治+根号分治
又被\(SAM\)神仙题教育了一顿。。。
字符串上树了,看起来很不可做的样子。由于本题需要维护的是所有链的信息,容易想到点分治。
既然要使用点分治,不可避免地要面临统计子树信息和两条链合并的难题。
比如当前连通块的重心为\(x\),要统计信息并合并\(u \rightarrow x \rightarrow v\),如何操作呢?
我们统计答案时,必然需要把信息集中在合并点处,也就是说,链\(u \rightarrow x\)对应模式串\(T\)的区间\([l,i]\),而链\(x \rightarrow v\)对应模式串\(T\)的区间\([i,r]\),那么我们可以在\(i\)处统计贡献。
考虑\(SAM\)上只有后缀会拥有一个指定的节点,我们先考虑\(u \rightarrow x\)的一部分,它将作为\([l,i]\)出现,而节点\(x\)将作为\(i\)位置。
发现我们遍历子树时,不是我们习惯的向后插入字符,而是向前插入字符。
但是这仅仅是查找而已,强大的\(SAM\)可以完成这个任务。具体来说,我们考虑原串一定是新串的后缀,那么对应于\(parent\)树上,原串所在位置\(p\)与新串所在位置相同或是它的父亲。当长度小于\(len_p\)时,要么失配(只要记录\(p\)的任意一个后缀位置即可),要么长度\(+1\);而长度大于\(len_p\)时,它可能失配,也可能到达它的其中一个儿子节点。我们需要建立一种全新的转移边,表示一个节点到达最大长度之后加入字符\(c\)会到达哪个儿子,可以在建\(SAM\)时预处理(不明白可以看代码,很容易理解)。
而链\(x \rightarrow v\)只是\(S\)的反串\(S^r\)上同样的操作而已。
接下来,就是合并答案了。
对于\(i\)位置,我们会在\(S\)的\(SAM\)上记录其后缀,在\(S^r\)的\(SAM\)上记录其前缀,我们发现只要是\(i\)的前缀、后缀都可以记录入答案,也就是\(i\)所在的节点以及其\(parent\)树上的祖先,我们可以最后一次处理统计答案,一次计算时间复杂度为\(O(m+size)\)。
但是点分治需要有去重,直接去重一次时间复杂度仍然是\(O(m+size)\)。
我们一分析复杂度\(O(n \log n + nm)\),白干了???
根号分治就出场了。
除了\(O(m+size)\)的统计算法,我们显然还有\(O(size^2)\)的暴力。
我们设定一个阈值\(B=\sqrt n\),使得在\(\le B\)时跑暴力,\(>B\)时跑\(O(m+size)\)的算法。
由于点分治一层子树大小必然比上一层至少减半,那么我们计算一下,在最劣情况下,点分树呈现二叉树的形式,第\(k\)层的子树大小不会超过\(\frac{n}{2^k}\),在本题数据范围中第\(8\)层的数据已经小于\(\sqrt n\)了,这时候子树个数上限大约是\(2^8=256\),可以近似看成\(\sqrt n\)了。所以这一部分的复杂度为\(O(n \sqrt n)\)。
对于大子树,可以把小子树看成它们的叶子节点,那么大子树个数与小子树个数同阶,这一部分时间复杂度为\(O(m \sqrt n)\)。
现在我们的去重仍然存在时间复杂度问题。
我们需要观察重复的是哪一部分,也就是说原来的\(u \rightarrow x \rightarrow v\)的\(u,v\)在同一子树中。
同样进行根号分治,对于大子树仍然执行\(O(m+size)\)算法。而小子树可以\(O(size)\)枚举子树中一个点,记录其到达当前点分到的重心节点的路径,然后返回同一子树中进行暴力统计,时间复杂度为\(O(size^2)\)。
进行平衡后,时间复杂度为\(O(n \sqrt n)\),可以轻松通过此题。
\(Code:\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define ll long long
#define N 50005
using namespace std;
const int INF=1000000007;
char s[N],t[N];
int n,m,x,y,block,rc[N],c[N];
ll ans=0;
struct edge
{
int nxt,v;
edge (int Nxt=0,int V=0)
{
nxt=Nxt,v=V;
}
}e[N << 1];
int tot,fr[N];
int lsz,rt,rtsz,sz[N],f[N],q[N];
bool vis[N];
void add(int x,int y)
{
++tot;
e[tot]=edge(fr[x],y),fr[x]=tot;
}
struct SAM
{
int cnt=1,lst=1,tr[N << 1][26],pre[N << 1],len[N << 1],R[N << 1];
int g[N << 1],a[N << 1],loc[N],ct[N << 1],son[N << 1][26];
char t[N];
void ins(int c,int I)
{
int p=lst,np=++cnt;
loc[I]=lst=np,len[np]=len[p]+1,R[np]=I,++ct[np];
for (;p && !tr[p][c];p=pre[p])
tr[p][c]=np;
if (!p)
pre[np]=1; else
{
int q=tr[p][c];
if (len[p]+1==len[q])
pre[np]=q; else
{
int g=++cnt;
memcpy(tr[g],tr[q],sizeof(tr[q]));
len[g]=len[p]+1,pre[g]=pre[q],R[g]=R[q];
for (;p && tr[p][c]==q;p=pre[p])
tr[p][c]=g;
pre[q]=pre[np]=g;
}
}
}
void build()
{
for (int i=1;i<=m;++i)
ins(t[i]-'a',i);
memset(c,0,(m+1)*sizeof(int));
for (int i=1;i<=cnt;++i)
++c[len[i]];
for (int i=1;i<=m;++i)
c[i]+=c[i-1];
for (int i=1;i<=cnt;++i)
a[c[len[i]]--]=i;
for (int i=cnt;i>=2;--i)
{
int u=a[i];
ct[pre[u]]+=ct[u];
if (pre[u]!=1)
son[pre[u]][t[R[u]-len[pre[u]]]-'a']=u;
}
}
int nxt(int u,int L,char c)
{
if (len[u]>=L)
return (t[R[u]-L+1]==c)?u:0;
return son[u][c-'a'];
}
void clear()
{
memset(g,0,(cnt+1)*sizeof(int));
}
void dfs(int u,int F,int Len,int st)
{
if (!st)
return;
++g[st];
for (int i=fr[u];i;i=e[i].nxt)
{
int v=e[i].v;
if (v==F || vis[v])
continue;
dfs(v,u,Len+1,nxt(st,Len+1,s[v]));
}
}
void calc()
{
for (int i=2;i<=cnt;++i)
g[a[i]]+=g[pre[a[i]]];
}
}s1,s2;
void findrt(int u,int F,int rn)
{
int mx=-1;
sz[u]=1;
for (int i=fr[u];i;i=e[i].nxt)
{
int v=e[i].v;
if (v==F || vis[v])
continue;
findrt(v,u,rn);
sz[u]+=sz[v],mx=max(mx,sz[v]);
}
mx=max(mx,rn-sz[u]);
if (mx<rtsz)
rtsz=mx,rt=u;
}
void getrt(int u,int rn)
{
rtsz=INF,findrt(u,0,rn);
}
void dfs1(int u,int F)
{
q[++q[0]]=u;
for (int i=fr[u];i;i=e[i].nxt)
{
int v=e[i].v;
if (v==F || vis[v])
continue;
f[v]=u;
dfs1(v,u);
}
}
void dfs2(int u,int F,int st,int t)
{
if (!st)
return;
ans+=t*s1.ct[st];
for (int i=fr[u];i;i=e[i].nxt)
{
int v=e[i].v;
if (v==F || vis[v])
continue;
dfs2(v,u,s1.tr[st][rc[v]],t);
}
}
void calc(int u,int c,int t)
{
s1.clear(),s2.clear();
if (t==1)
{
s1.dfs(u,0,1,s1.tr[1][rc[u]]);
s2.dfs(u,0,1,s2.tr[1][rc[u]]);
} else
{
s1.dfs(u,0,2,s1.nxt(s1.tr[1][c],2,s[u]));
s2.dfs(u,0,2,s2.nxt(s2.tr[1][c],2,s[u]));
}
s1.calc(),s2.calc();
for (int i=1;i<=m;++i)
ans+=(ll)t*s1.g[s1.loc[i]]*s2.g[s2.loc[m-i+1]];
}
void calc2(int u,int rt)
{
q[0]=0,f[u]=rt,dfs1(u,0);
for (int i=1;i<=q[0];++i)
{
int t=q[i],st=1;
while (t!=rt)
st=s1.tr[st][rc[t]],t=f[t];
st=s1.tr[st][rc[t]],st=s1.tr[st][rc[u]];
dfs2(u,0,st,-1);
}
}
void solve(int u)
{
if (lsz<=block)
{
q[0]=0,dfs1(u,0);
for (int i=1;i<=q[0];++i)
dfs2(q[i],0,s1.tr[1][rc[q[i]]],1);
return;
}
int tsz=lsz;
vis[u]=true;
calc(u,1,1);
for (int i=fr[u];i;i=e[i].nxt)
{
int v=e[i].v;
if (vis[v])
continue;
int ts=(sz[u]<sz[v])?tsz-sz[u]:sz[v];
if (ts>block)
calc(v,rc[u],-1); else
calc2(v,u);
}
for (int i=fr[u];i;i=e[i].nxt)
{
int v=e[i].v;
if (vis[v])
continue;
lsz=(sz[u]<sz[v])?tsz-sz[u]:sz[v];
getrt(v,lsz);
solve(rt);
}
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<n;++i)
{
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
scanf("%s%s",s+1,t+1);
for (int i=1;i<=n;++i)
rc[i]=s[i]-'a';
block=(int)sqrt(n);
for (int i=1;i<=m;++i)
s1.t[i]=t[i],s2.t[i]=t[m-i+1];
s1.build(),s2.build();
lsz=n,getrt(1,n);
solve(rt);
printf("%lld\n",ans);
return 0;
}