「bzoj 4180: 字符串计数」
真是一道好题
首先根据一个非常显然的贪心,如果给出了一个串\(S\),我们如何算最小操作次数呢
非常简单,我们直接把\(S\)拉到\(T\)的\(SAM\)上去跑,如果跑不动了就停下来,重新回到\(1\)继续跑
于是我们建出一个\(SAM\)之后可以写一个这样的暴力,设\(d[i][j][k]\)表示从\(i\)点到\(j\)点走\(i\)条边的最长路,对于那些走不动的边,我们可以接到\(1\)号节点对应的出边上去,边权为\(1\),其余的边权为\(0\),矩阵优化一下就是\(O(|T|^3logn)\)的复杂度
显然\(|T|\)并不允许我们开下如此之大的转移矩阵,尝试换一个角度来考虑这个问题
我们发现我们问题的本质就是最大化最小值,这是不是可以二分一下呢
于是现在的问题变成了对于一个二分出的操作次数\(mid\),判断答案是否能够更大
显然我们如果使用\(mid\)此操作构造出来的串长度小于\(n\),那么我们就可以断定答案可能会更大一些
于是又把问题转化成了利用\(mid\)次操作构造出来的字符串的最小长度
这个如何求呢,我们考虑一次操作无非就是从\(1\)的某一个出边指向的节点到另一个\(1\)的出边指向的节点,所以我们求出这些节点两两之间的最短路就好了
于是我们现在又可以利用矩阵转移了,复杂度\(O(|T|+|c|^3log^2n)\),\(|c|\)为字符集大小
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define LL long long
#define re register
#define inf 922337203685477580
#define maxn 200005
LL n;
char S[maxn>>1];
int m,lst=1,cnt=1,tot;
int len[maxn],fa[maxn],son[maxn][4],q[maxn],vis[maxn],c[maxn];
LL a[4][4],ans[4][4],t[4][4],d[maxn];
inline void ins(int c) {
int p=++cnt,f=lst;lst=p;
len[p]=len[f]+1;
while(f&&!son[f][c]) son[f][c]=p,f=fa[f];
if(!f) {fa[p]=1;return;}
int x=son[f][c];
if(len[f]+1==len[x]) {fa[p]=x;return;}
int y=++cnt;len[y]=len[f]+1;
fa[y]=fa[x],fa[x]=fa[p]=y;
son[y][0]=son[x][0];son[y][1]=son[x][1];
son[y][2]=son[x][2];son[y][3]=son[x][3];
while(f&&son[f][c]==x) son[f][c]=y,f=fa[f];
}
inline void did_t() {
LL mid[4][4];
for(re int i=0;i<4;i++)
for(re int j=0;j<4;j++) mid[i][j]=t[i][j],t[i][j]=inf;
for(re int k=0;k<4;k++)
for(re int i=0;i<4;i++)
for(re int j=0;j<4;j++)
t[i][j]=min(t[i][j],mid[i][k]+mid[k][j]);
}
inline void did_ans() {
LL mid[4][4];
for(re int i=0;i<4;i++)
for(re int j=0;j<4;j++) mid[i][j]=ans[i][j],ans[i][j]=inf;
for(re int k=0;k<4;k++)
for(re int i=0;i<4;i++)
for(re int j=0;j<4;j++)
ans[i][j]=min(ans[i][j],mid[i][k]+t[k][j]);
}
inline LL solve(LL now) {
for(re int i=0;i<4;i++)
for(re int j=0;j<4;j++) t[i][j]=a[i][j],ans[i][i]=inf;
for(re int i=0;i<4;i++) ans[i][i]=0;
LL b=now;
while(now) {if(now&1ll) did_ans();now>>=1ll;did_t();}
LL tmp=inf;
for(re int i=0;i<4;i++)
for(re int j=0;j<4;j++) tmp=min(tmp,ans[i][j]);
return tmp+b;
}
inline int check(LL now) {return solve(now)<=n;}
int main() {
scanf("%lld",&n);scanf("%s",S+1);m=strlen(S+1);
for(re int i=1;i<=m;i++) ins(S[i]-'A');
for(re int i=0;i<4;i++)
for(re int j=0;j<4;j++) a[i][j]=inf;
for(re int i=0;i<4;i++) {
tot=0;q[++tot]=son[1][i];
memset(vis,0,sizeof(vis));
memset(d,20,sizeof(d));d[q[1]]=0;
for(re int j=1;j<=tot;j++) {
int x=q[j];
for(re int k=0;k<4;k++) {
if(vis[son[x][k]]) continue;
if(!son[x][k]) a[i][k]=min(a[i][k],d[x]);
else vis[son[x][k]]=1,d[son[x][k]]=d[x]+1,q[++tot]=son[x][k];
}
}
}
LL l=1,r=n,ans=0;
while(l<=r) {
LL mid=l+r>>1ll;
if(check(mid)) l=mid+1,ans=mid;
else r=mid-1;
}
if(solve(ans)<n) ans++;
printf("%lld\n",ans);
return 0;
}