AC自动机+矩阵加速--P3502 [POI2010]CHO-Hamsters
problem
给出 \(n\) 个互不包含的字符串,要求你求出一个最短的字符串 \(S\),使得这 \(n\) 个字符串在 \(S\) 中总共至少出现 \(m\) 次,问 \(S\) 最短是多少。
solution
- 我们首先转化题意:
有 \(n\) 个点,两个点 \(i,j\) 之间的权值为将第 \(j\) 种字符串接在第 \(i\) 种结尾的最小增加量(注意此处 \(i\) 可以和 \(j\) 相等)。求一条经过 \(m\) 个点的最短路径,起点和终点自己选择。
暴力求法很好想,因为 \(n\leq 200\) ,所以考虑Floyd。设 \(f_{k,i,j}\)为 \(i\) 到 \(j\) 经过 \(k\) 个点的最短路径,然后一层一层枚举,直到 \(m\) ,就求出来了。
- 矩阵优化
设 \(dp_{i,j}\)表示现在串里有 \(i\) 个字符串,并以第 \(j\) 种字符串结尾的最短长度, \(dis_{i,j}\)表示将第 \(j\) 种字符串接在第 \(i\) 种结尾的最小增加量。
那么转移方程也很好得到: \(dp_{i,j}=\min\{dp_{i,k}+dis_{k,j}\}\)
- 实现细节
建AC自动机,预处理出两两结尾之间最短距离,矩阵\(a_{i,j}\)最开始表示\(i\)到\(j\)中间无其他字符结尾的最短距离
因为最开始的矩阵中每个\(a_{i,j}\)都包含两个结束点,所以我们需要求\(m-1\)次,其次我们统计的是终点到终点的距离,第一个字符串的长度没有计算,需要单独加上
code
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
#include <algorithm>
#include <vector>
#define ll long long
using namespace std;
int read(){
int x = 1,a = 0;char ch = getchar();
while (ch < '0'||ch > '9'){if (ch == '-') x = -1;ch = getchar();}
while (ch >= '0'&&ch <= '9'){a =a*10+ch-'0';ch = getchar();}
return x*a;
}
const int maxn = 2e6+10;
int n,m;
char c[maxn];
vector<int> g;
int cnt = 1,trie[maxn][30],l[maxn];
void build(char *a,int pos){
int len = strlen(a),root = 1;
for (int i = 0;i < len;i++){
int nxt = a[i]-'a';
if (!trie[root][nxt]) trie[root][nxt] = ++cnt;
root = trie[root][nxt];
}
g.push_back(root);l[root] = len;
}
int fail[maxn],dis[maxn],vis[maxn];
void getfail(){
queue<int> q;q.push(1);
for (int i = 0;i < 26;i++) trie[0][i] = 1;
while (!q.empty()){
int root = q.front();q.pop();
for (int i = 0;i < 26;i++){
if (!trie[root][i]) trie[root][i] = trie[fail[root]][i];
else{
fail[trie[root][i]] = trie[fail[root]][i];
q.push(trie[root][i]);
}
}
}
}
void bfs(int s){
memset(vis,0,sizeof(vis));
queue<int> q;q.push(s);dis[s] = 0;
while (!q.empty()){
int x = q.front();q.pop();
for (int i = 0;i < 26;i++){
int to = trie[x][i];
if (vis[to]) continue;
dis[to] = dis[x] + 1;vis[to] = 1;
q.push(to);
}
}
}
struct node{
ll a[205][205];
}I,a;
node operator * (const node x,const node y){
node z;
memset(z.a,63,sizeof(z.a));
for (int i = 0;i < n;i++){
for (int j = 0;j < n;j++){
for (int k = 0;k < n;k++){
z.a[i][j] = min(z.a[i][j],x.a[i][k]+y.a[k][j]);
}
}
}
return z;
}
void qpow(node aa,int k){
if(!k) return;
I = aa,k--;
while (k){
if (k&1) I = I*aa;
aa = aa*aa;
k >>= 1;
}
}
signed main(){
n = read(),m = read();
if(m == 0) {printf("0");return 0;}
for (int i = 1;i <= n;i++){
scanf ("%s",c);
build(c,i);
}
getfail();
for (int i = 0;i < g.size();i++){
bfs(g[i]);
for (int j = 0;j < g.size();j++){
a.a[i][j] = dis[g[j]];
}
}
ll ans = 1e18+7;
qpow(a,m-1);
for (int i = 0;i < g.size();i++){
for (int j = 0;j < g.size();j++){
ans = min(ans,I.a[i][j] + l[g[i]]);
}
}
printf("%lld\n",ans);
return 0;
}