[NOI2016]优秀的拆分
好像经典的做法是关键点+调和级数,但是不太会;于是就打个SAM+暴力启发式合并
我们只需要算一下每个\(i\)是多少个AA串的结尾,是多少个BB串的开头,我们发现这两个问题其实是等价的;我们算每个\(i\)是多少个AA串的结尾即可;
其实我们只需要对于每个\(i\)算有多少\(j<i\)满足\(i-j\leq{\rm LCS}(i,j)\)即可;
大力启发式合并,考虑把小集合合并到大集合上产生的影响,发现非常好考虑,就是一个区间加和一个区间查,线段树就好了;
复杂度\(O(n\log^2n)\),代码
#include<bits/stdc++.h>
#define re register
#define LL long long
const int maxn=6e4+5;
const int M=maxn*20;
char S[maxn>>1];
int len[maxn],son[maxn][26],fa[maxn];
int tax[maxn>>1],A[maxn],n,lst,cnt,tot,tmp;
int a[maxn>>1],b[maxn>>1],ans[maxn>>1],g[maxn>>1];
int l[M],r[M],d[M],tag[M],rt[maxn],pos[maxn>>1];
inline void ins(int c) {
int p=++cnt,f=lst;lst=p;len[p]=len[f]+1;
while(f&&!son[f][c])son[f][c]=p,f=fa[f];
if(!f){fa[p]=1;return;}int x=son[f][c];
if(len[f]+1==len[x]){fa[p]=x;return;}
int y=++cnt;len[y]=len[f]+1,fa[y]=fa[x],fa[x]=fa[p]=y;
for(re int i=0;i<26;i++)son[y][i]=son[x][i];
while(f&&son[f][c]==x)son[f][c]=y,f=fa[f];
}
inline int newnode() {++tot;l[tot]=r[tot]=d[tot]=tag[tot]=0;return tot;}
int chg(int nw,int x,int y,int pos) {
if(!nw)nw=newnode();d[nw]++;if(x==y)return nw;int mid=x+y>>1;
(pos<=mid?l[nw]=chg(l[nw],x,mid,pos):r[nw]=chg(r[nw],mid+1,y,pos));return nw;
}
void add(int nw,int x,int y,int lx,int ry) {
if(!nw)return;if(lx<=x&&ry>=y){tag[nw]++;return;}int mid=x+y>>1;
if(lx<=mid)add(l[nw],x,mid,lx,ry);if(ry>mid)add(r[nw],mid+1,y,lx,ry);
}
void dfs(int nw,int x,int y,int v) {
if(!nw)return;v+=tag[nw];if(x==y){a[++tmp]=x;b[tmp]=ans[x]+v;return;}
int mid=x+y>>1;dfs(l[nw],x,mid,v);dfs(r[nw],mid+1,y,v);
}
int mof(int nw,int x,int y,int pos,int v) {
if(!nw)nw=newnode();d[nw]++,v-=tag[nw];if(x==y){ans[x]=v;return nw;}int mid=x+y>>1;
(pos<=mid?l[nw]=mof(l[nw],x,mid,pos,v):r[nw]=mof(r[nw],mid+1,y,pos,v));return nw;
}
int ask(int nw,int x,int y,int lx,int ry) {
if(!nw)return 0;if(lx<=x&&ry>=y)return d[nw];int mid=x+y>>1;
return (lx<=mid?ask(l[nw],x,mid,lx,ry):0)+(ry>mid?ask(r[nw],mid+1,y,lx,ry):0);
}
inline void Main() {
lst=cnt=1;tot=0;
memset(len,0,sizeof(len));memset(son,0,sizeof(son));
memset(fa,0,sizeof(fa));memset(tax,0,sizeof(tax));
memset(rt,0,sizeof(rt));memset(ans,0,sizeof(ans));
for(re int i=1;i<=n;i++)ins(S[i]-'a'),pos[i]=lst;
for(re int i=1;i<=cnt;i++)tax[len[i]]++;
for(re int i=1;i<=n;i++)tax[i]+=tax[i-1];
for(re int i=1;i<=cnt;i++)A[tax[len[i]]--]=i;
for(re int i=1;i<=n;i++)rt[pos[i]]=chg(rt[pos[i]],1,n,i);
for(re int i=cnt;i>1;--i) {
int x=A[i],f=fa[x];
if(d[rt[x]]>d[rt[f]])std::swap(rt[x],rt[f]);
tmp=0;dfs(rt[x],1,n,0);
for(re int j=1;j<=tmp;j++) {
if(len[f]) b[j]+=ask(rt[f],1,n,a[j]-len[f],a[j]-1);
if(len[f]) add(rt[f],1,n,a[j]+1,a[j]+len[f]);
}
for(re int j=1;j<=tmp;j++) rt[f]=mof(rt[f],1,n,a[j],b[j]);
}
tmp=0;dfs(rt[1],1,n,0);
for(re int i=1;i<=tmp;i++)ans[a[i]]=b[i];
}
int main() {
int T;scanf("%d",&T);
for(LL Ans;T;--T) {
scanf("%s",S+1);n=strlen(S+1);Main();
for(re int i=1;i<=n;i++)g[i]=ans[i];
std::reverse(S+1,S+n+1);Main();
std::reverse(ans+1,ans+n+1);Ans=0;
for(re int i=1;i<n;i++)Ans+=1ll*g[i]*ans[i+1];
printf("%lld\n",Ans);
}
return 0;
}