「BZOJ 4502」串
「BZOJ 4502」串
题目描述
兔子们在玩字符串的游戏。首先,它们拿出了一个字符串集合 \(S\),然后它们定义一个字符串为“好”的,当且仅当它可以被分成非空的两段,其中每一段都是字符串集合 \(S\) 中某个字符串的前缀。比如对于字符串集合 \(\{ "abc","bca" \}\),字符串 \("abb"\),\("abab"\)是“好”的 \(("abb"="ab"+"b", abab="ab"+"ab")\) ,而字符串 \(“bc”\)不是“好”的。
兔子们想知道,一共有多少不同的“好”的字符串。
\(1 \leq N \leq 10000, 1 \leq |S| \leq 30\)
解题思路 :
观察发现,对于同一个串可能会有多种划分方式形成两个前缀拼接的形式,直接大力计算不方便处理重复的情况
此时不妨统计每一种答案串中最具有“特征”的那一种划分方式,在所有划分方式中,最小化第二个串的长度
也就是说,如果第一个串已经确定,第二个串的前缀与第一个串的公共部分全部划给第一个串
问题的一部分转化为一个 \(Trie\) 树上求 \(Borders\) 的问题,也就是 \(AC\) 自动机的 \(fail\) 指针,所以可以把问题规约到 \(AC\) 自动机上面
此时答案的形态有两种,拼接起来的串就是原串的一个前缀,或者是两个前缀拼接起来
考虑第一种情况,本质上是对于 \(AC\) 自动机中每一个 \(fail \neq root\) 的点,其到 \(root\) 的路径代表的前缀就是一个合法的答案
对于第二种情况,根据 \(AC\) 自动机的性质,匹配串和合法路径一一对应,所以问题可以转化为对合法路径计数
于是考虑在 \(AC\) 自动机上枚举第一个串,通过 \(dp\) 处理出每一个 \(Trie\) 树节点作为路径终点的答案,通过走树边和 \(fail\) 边来转移
设 \(f[i][j][k]\) 表示总长度 \(i\) 的串走到了节点 \(j\) ,枚举的第一个串的长度为 \(k\) 的答案
转移就直接走 \(Trie\) 图的边转移,但要保证任意时刻拼接起来的串长要能够等于 \(i\) ,也就是 \(dep(j) + k > i\)
但是这样的复杂度是 \(O(n\times60^2\times26)\) 的,时间复杂度不能够接受,考虑对状态进行简化
观察发现,将不等式稍微加变换就是 \(dep(j) > i - k\) ,那么只需要记录 \(f[i][j]\) 表示第二个串长为 \(i\) ,当前到达了节点 \(j\) 的方案数,现在复杂度是 \(O(n \times 26 \times 60)\)
/*program by mangoyang*/
#include<bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int f = 0, ch = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
if(f) x = -x;
}
#define int ll
#define par pair<int, int>
#define mp make_pair
#define fi first
#define se second
const int N = 3000005;
char s[N];
int f[65][N], n;
struct ACautomaton{
queue<int> q; int ch[N][26], dep[N], nxt[N][26], fail[N], size;
inline ACautomaton(){
for(int i = 0; i < 26; i++) nxt[0][i] = 1; size = 1;
}
inline int newnode(int x){ return dep[++size] = x, size; }
inline void ins(char *s){
int p = 1, len = strlen(s);
for(int i = 0; i < len; i++){
int c = s[i] - 'a';
if(!ch[p][c]) ch[p][c] = nxt[p][c] = newnode(i + 1);
p = ch[p][c];
}
}
inline void build(){
for(q.push(1); !q.empty(); ){
int u = q.front(); q.pop();
for(int i = 0; i < 26; i++){
int v = nxt[u][i];
if(!v) nxt[u][i] = nxt[fail[u]][i];
else fail[v] = nxt[fail[u]][i], q.push(v);
}
}
}
inline void solve(){
int ans = 0;
for(int i = 2; i <= size; i++) ans += (fail[i] != 1);
for(int i = 1; i <= size; i++)
for(int j = 0; j < 26; j++)
if(!ch[i][j] && nxt[i][j] != 1) f[1][nxt[i][j]]++;
for(int i = 1; i <= 60; i++)
for(int j = 1; j <= size; j++) if(f[i][j]){
for(int c = 0; c < 26; c++)
if(dep[nxt[j][c]] > i) f[i+1][nxt[j][c]] += f[i][j];
}
for(int i = 1; i <= 60; i++)
for(int j = 1; j <= size; j++) ans += f[i][j];
cout << ans;
}
}van;
signed main(){
read(n);
for(int i = 1; i <= n; i++) scanf("%s", s), van.ins(s);
van.build(), van.solve();
return 0;
}