Luogu4218 [CTSC2010]珠宝商

Luogu4218 [CTSC2010]珠宝商

\(SAM\)+点分治+根号分治

又被\(SAM\)神仙题教育了一顿。。。

字符串上树了,看起来很不可做的样子。由于本题需要维护的是所有链的信息,容易想到点分治。

既然要使用点分治,不可避免地要面临统计子树信息和两条链合并的难题。

比如当前连通块的重心为\(x\),要统计信息并合并\(u \rightarrow x \rightarrow v\),如何操作呢?

我们统计答案时,必然需要把信息集中在合并点处,也就是说,链\(u \rightarrow x\)对应模式串\(T\)的区间\([l,i]\),而链\(x \rightarrow v\)对应模式串\(T\)的区间\([i,r]\),那么我们可以在\(i\)处统计贡献。

考虑\(SAM\)上只有后缀会拥有一个指定的节点,我们先考虑\(u \rightarrow x\)的一部分,它将作为\([l,i]\)出现,而节点\(x\)将作为\(i\)位置。

发现我们遍历子树时,不是我们习惯的向后插入字符,而是向前插入字符。

但是这仅仅是查找而已,强大的\(SAM\)可以完成这个任务。具体来说,我们考虑原串一定是新串的后缀,那么对应于\(parent\)树上,原串所在位置\(p\)与新串所在位置相同或是它的父亲。当长度小于\(len_p\)时,要么失配(只要记录\(p\)的任意一个后缀位置即可),要么长度\(+1\);而长度大于\(len_p\)时,它可能失配,也可能到达它的其中一个儿子节点。我们需要建立一种全新的转移边,表示一个节点到达最大长度之后加入字符\(c\)会到达哪个儿子,可以在建\(SAM\)时预处理(不明白可以看代码,很容易理解)。

而链\(x \rightarrow v\)只是\(S\)的反串\(S^r\)上同样的操作而已。

接下来,就是合并答案了。

对于\(i\)位置,我们会在\(S\)\(SAM\)上记录其后缀,在\(S^r\)\(SAM\)上记录其前缀,我们发现只要是\(i\)的前缀、后缀都可以记录入答案,也就是\(i\)所在的节点以及其\(parent\)树上的祖先,我们可以最后一次处理统计答案,一次计算时间复杂度为\(O(m+size)\)

但是点分治需要有去重,直接去重一次时间复杂度仍然是\(O(m+size)\)

我们一分析复杂度\(O(n \log n + nm)\),白干了???

根号分治就出场了。

除了\(O(m+size)\)的统计算法,我们显然还有\(O(size^2)\)的暴力。

我们设定一个阈值\(B=\sqrt n\),使得在\(\le B\)时跑暴力,\(>B\)时跑\(O(m+size)\)的算法。

由于点分治一层子树大小必然比上一层至少减半,那么我们计算一下,在最劣情况下,点分树呈现二叉树的形式,第\(k\)层的子树大小不会超过\(\frac{n}{2^k}\),在本题数据范围中第\(8\)层的数据已经小于\(\sqrt n\)了,这时候子树个数上限大约是\(2^8=256\),可以近似看成\(\sqrt n\)了。所以这一部分的复杂度为\(O(n \sqrt n)\)

对于大子树,可以把小子树看成它们的叶子节点,那么大子树个数与小子树个数同阶,这一部分时间复杂度为\(O(m \sqrt n)\)

现在我们的去重仍然存在时间复杂度问题。

我们需要观察重复的是哪一部分,也就是说原来的\(u \rightarrow x \rightarrow v\)\(u,v\)在同一子树中。

同样进行根号分治,对于大子树仍然执行\(O(m+size)\)算法。而小子树可以\(O(size)\)枚举子树中一个点,记录其到达当前点分到的重心节点的路径,然后返回同一子树中进行暴力统计,时间复杂度为\(O(size^2)\)

进行平衡后,时间复杂度为\(O(n \sqrt n)\),可以轻松通过此题。

\(Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define ll long long
#define N 50005
using namespace std;
const int INF=1000000007;
char s[N],t[N];
int n,m,x,y,block,rc[N],c[N];
ll ans=0;
struct edge
{
    int nxt,v;
    edge (int Nxt=0,int V=0)
    {
        nxt=Nxt,v=V;
    }
}e[N << 1];
int tot,fr[N];
int lsz,rt,rtsz,sz[N],f[N],q[N];
bool vis[N];
void add(int x,int y)
{
    ++tot;
    e[tot]=edge(fr[x],y),fr[x]=tot;
}
struct SAM
{
    int cnt=1,lst=1,tr[N << 1][26],pre[N << 1],len[N << 1],R[N << 1];
    int g[N << 1],a[N << 1],loc[N],ct[N << 1],son[N << 1][26];
    char t[N];
    void ins(int c,int I)
    {
        int p=lst,np=++cnt;
        loc[I]=lst=np,len[np]=len[p]+1,R[np]=I,++ct[np];
        for (;p && !tr[p][c];p=pre[p])
            tr[p][c]=np;
        if (!p)
            pre[np]=1; else
            {
                int q=tr[p][c];
                if (len[p]+1==len[q])
                    pre[np]=q; else
                    {
                        int g=++cnt;
                        memcpy(tr[g],tr[q],sizeof(tr[q]));
                        len[g]=len[p]+1,pre[g]=pre[q],R[g]=R[q];
                        for (;p && tr[p][c]==q;p=pre[p])
                            tr[p][c]=g;
                        pre[q]=pre[np]=g;
                    }
            }
    }
    void build()
    {
        for (int i=1;i<=m;++i)
            ins(t[i]-'a',i);
        memset(c,0,(m+1)*sizeof(int));
        for (int i=1;i<=cnt;++i)
            ++c[len[i]];
        for (int i=1;i<=m;++i)
            c[i]+=c[i-1];
        for (int i=1;i<=cnt;++i)
            a[c[len[i]]--]=i;
        for (int i=cnt;i>=2;--i)
        {
            int u=a[i];
            ct[pre[u]]+=ct[u];
            if (pre[u]!=1)
                son[pre[u]][t[R[u]-len[pre[u]]]-'a']=u;
        }
    }
    int nxt(int u,int L,char c)
    {
        if (len[u]>=L)
            return (t[R[u]-L+1]==c)?u:0;
        return son[u][c-'a'];
    }
    void clear()
    {
        memset(g,0,(cnt+1)*sizeof(int));
    }
    void dfs(int u,int F,int Len,int st)
    {
        if (!st)
            return;
        ++g[st];
        for (int i=fr[u];i;i=e[i].nxt)
        {
            int v=e[i].v;
            if (v==F || vis[v])
                continue;
            dfs(v,u,Len+1,nxt(st,Len+1,s[v]));
        }
    }
    void calc()
    {
        for (int i=2;i<=cnt;++i)
            g[a[i]]+=g[pre[a[i]]];
    }
}s1,s2;
void findrt(int u,int F,int rn)
{
    int mx=-1;
    sz[u]=1;
    for (int i=fr[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (v==F || vis[v])
            continue;
        findrt(v,u,rn);
        sz[u]+=sz[v],mx=max(mx,sz[v]);
    }
    mx=max(mx,rn-sz[u]);
    if (mx<rtsz)
        rtsz=mx,rt=u;
}
void getrt(int u,int rn)
{
    rtsz=INF,findrt(u,0,rn);
}
void dfs1(int u,int F)
{
    q[++q[0]]=u;
    for (int i=fr[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (v==F || vis[v])
            continue;
        f[v]=u;
        dfs1(v,u);
    }
}
void dfs2(int u,int F,int st,int t)
{
    if (!st)
        return;
    ans+=t*s1.ct[st];
    for (int i=fr[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (v==F || vis[v])
            continue;
        dfs2(v,u,s1.tr[st][rc[v]],t);
    }
}
void calc(int u,int c,int t)
{
    s1.clear(),s2.clear();
    if (t==1)
    {
        s1.dfs(u,0,1,s1.tr[1][rc[u]]);
        s2.dfs(u,0,1,s2.tr[1][rc[u]]);
    } else
    {
        s1.dfs(u,0,2,s1.nxt(s1.tr[1][c],2,s[u]));
        s2.dfs(u,0,2,s2.nxt(s2.tr[1][c],2,s[u]));
    }
    s1.calc(),s2.calc();
    for (int i=1;i<=m;++i)
        ans+=(ll)t*s1.g[s1.loc[i]]*s2.g[s2.loc[m-i+1]];
}
void calc2(int u,int rt)
{
    q[0]=0,f[u]=rt,dfs1(u,0);
    for (int i=1;i<=q[0];++i)
    {
        int t=q[i],st=1;
        while (t!=rt)
            st=s1.tr[st][rc[t]],t=f[t];
        st=s1.tr[st][rc[t]],st=s1.tr[st][rc[u]];
        dfs2(u,0,st,-1);
    }
}
void solve(int u)
{
    if (lsz<=block)
    {
        q[0]=0,dfs1(u,0);
        for (int i=1;i<=q[0];++i)
            dfs2(q[i],0,s1.tr[1][rc[q[i]]],1);
        return;
    }
    int tsz=lsz;
    vis[u]=true;
    calc(u,1,1);
    for (int i=fr[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (vis[v])
            continue;
        int ts=(sz[u]<sz[v])?tsz-sz[u]:sz[v];
        if (ts>block)
            calc(v,rc[u],-1); else
            calc2(v,u);
    }
    for (int i=fr[u];i;i=e[i].nxt)
    {
        int v=e[i].v;
        if (vis[v])
            continue;
        lsz=(sz[u]<sz[v])?tsz-sz[u]:sz[v];
        getrt(v,lsz);
        solve(rt);
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<n;++i)
    {
        scanf("%d%d",&x,&y);
        add(x,y),add(y,x);
    }
    scanf("%s%s",s+1,t+1);
    for (int i=1;i<=n;++i)
        rc[i]=s[i]-'a';
    block=(int)sqrt(n);
    for (int i=1;i<=m;++i)
        s1.t[i]=t[i],s2.t[i]=t[m-i+1];
    s1.build(),s2.build();
    lsz=n,getrt(1,n);
    solve(rt);
    printf("%lld\n",ans);
    return 0;
}
posted @ 2021-02-05 08:02  GK0328  阅读(63)  评论(0编辑  收藏  举报