CF1276F Asterisk Substrings
一、题目
二、解法
肯定不能直接考虑带星号的母串,我们考虑最后的答案会长什么样子。
结合样例 \(1\),我们可以知道答案是这样几种情况:empty,*,s,*s,s*,s*t
前 \(5\) 种都好算到爆炸,就是本质不同的子串魔改一下就行了。
考虑最后一个怎么算,考虑两个本质不同的 s*t
要不然是 \(s\) 不同,要不然是 \(t\) 不同。所以相比于直接枚举前缀,我们可以枚举每一个 \(s\) 的等价类来保证 \(s\) 相同,这样就只需要处理 \(t\) 不同的条件了。
同一个等价类还有一个条件是出现位置相同,设这个等价类的出现位置为 \(p\),那么就是把 \(p_i+2\) 这些后缀全部取出来,我们有多少个本质不同后缀的前缀。这个可以放在 \(\tt sam\) 上解决,首先对反串建出后缀自动机,首先加上所有后缀的深度,然后考虑去重,把所有后缀按 \(\tt dfs\) 序排序,然后减去相邻两个 \(\tt lca\) 的深度即可。
要算所有 \(s\) 对应的 \(t\),可以考虑启发式合并,因为父亲能继承儿子的出现位置。我们用线段树维护反串的 \(\tt dfs\) 序,上传的时候就减去左儿子最右边的点和右儿子最左边的点的 \(\tt lca\) 的深度,那么时间复杂度 \(O(n\log n)\)
代码给我打吐了,看来是好久没写毒瘤码力下降了。
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
using namespace std;
const int M = 200005;
const int N = 30*M;
#define ll long long
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,m,Ind,cnt,dp[2*M][20],lg[2*M],st[M],dep[M],rt[M];
int p1[M],p2[M],ls[N],rs[N],ml[N],mr[N];char s[M];
ll ans,sum[N];
struct node
{
int fa,len,ch[26];
};
int Min(int x,int y)
{
return dep[x]<dep[y]?x:y;
}
int lca(int l,int r)
{
if(!l || !r) return 0;
int k=lg[r-l+1];
return Min(dp[l][k],dp[r-(1<<k)+1][k]);
}
void up(int x)
{
sum[x]=sum[ls[x]]+sum[rs[x]]-dep[lca(mr[ls[x]],ml[rs[x]])];
ml[x]=ml[ls[x]]?ml[ls[x]]:ml[rs[x]];
mr[x]=mr[rs[x]]?mr[rs[x]]:mr[ls[x]];
}
int merge(int x,int y)
{
if(!x || !y) return x|y;
int k=++cnt;
ls[k]=merge(ls[x],ls[y]);
rs[k]=merge(rs[x],rs[y]);
up(k);
return k;
}
void ins(int &x,int l,int r,int id)
{
if(!x) x=++cnt;
if(l==r)
{
sum[x]=dep[dp[id][0]];
ml[x]=mr[x]=id;
return ;
}
int mid=(l+r)>>1;
if(mid>=id) ins(ls[x],l,mid,id);
else ins(rs[x],mid+1,r,id);
up(x);
}
struct Sam
{
int cnt,last;node a[M];vector<int> g[M];
Sam() {cnt=last=1;}
void add(int c)
{
int p=last,np=last=++cnt;
a[np].len=a[p].len+1;
for(;p && !a[p].ch[c];p=a[p].fa) a[p].ch[c]=np;
if(!p) a[np].fa=1;
else
{
int q=a[p].ch[c];
if(a[q].len==a[p].len+1) a[np].fa=q;
else
{
int nq=++cnt;
a[nq]=a[q];a[nq].len=a[p].len+1;
a[q].fa=a[np].fa=nq;
for(;p && a[p].ch[c]==q;p=a[p].fa) a[p].ch[c]=nq;
}
}
}
void build()
{
for(int i=2;i<=cnt;i++)
{
int j=a[i].fa;
g[j].push_back(i);
}
}
//the second sam
void dfs(int u,int fa)
{
dep[u]=a[u].len;
st[u]=++Ind;dp[Ind][0]=u;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
dfs(v,u);
dp[++Ind][0]=u;
}
}
void init()
{
dfs(1,0);
for(int j=1;(1<<j)<=Ind;j++)
for(int i=1;i+(1<<j)-1<=Ind;i++)
dp[i][j]=Min(dp[i][j-1],dp[i+(1<<j-1)][j-1]);
for(int i=2;i<=Ind;i++) lg[i]=lg[i>>1]+1;
}
//the first sam
void work(int u)
{
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
work(v);
rt[u]=merge(rt[u],rt[v]);
}
}
void solve()
{
work(1);
for(int i=1;i<=cnt;i++)
ans+=1ll*(a[i].len-a[a[i].fa].len)*sum[rt[i]];
}
}A,B;
signed main()
{
scanf("%s",s+1),n=strlen(s+1);
for(int i=1;i<=n;i++)
{
A.add(s[i]-'a');p1[i]=A.last;
ans+=i-A.a[A.a[p1[i]].fa].len;
if(i==n-1) ans<<=1;
}
for(int i=n;i>=1;i--)
{
B.add(s[i]-'a');p2[i]=B.last;
if(i>1) ans+=B.a[p2[i]].len-B.a[B.a[p2[i]].fa].len;
}
ans+=2;
A.build();B.build();B.init();
for(int i=1;i<n-1;i++)
ins(rt[p1[i]],1,Ind,st[p2[i+2]]);
A.solve();
printf("%lld\n",ans);
}