LA3490 Generator(KMP + 高斯消元)
题意
一开始给你一个长为 \(S\) 的字符串。
从空串开始,不断在后面添加一个 \([A, A + n]\) 的一个字符。
第一次包含 \(S\) 的时候会停止添加。问期望的添加次数。
有 \(T\) 组数据。
\(T \le 10, |S| \le 12, n \le 26\)
题解
单模板匹配的直接用 \(\mathrm{KMP}\) 就可以了。
那么我们枚举 \(S\) 第 \(i\) 位 \(S_i\) ,然后枚举当前这位填的数 \(c\) ,那么就会转移到 \(S_{\delta (i, c)}\) 。(这个过程和普通匹配跳 \(fail\) 是一样的)
然后是期望,我们考虑倒推。令 \(dp_i\) 为当前匹配了前 \(i\) 位期望添加的字符才能匹配完。
那么显然有如下的转移:
- \(i = |S|: dp_i = 0\)
- \(i \not = |S|: dp_i = (\sum_{c} dp_{\delta(i, c)}) + 1\)
这样转移显然会出环。这种 \(dp\) 直接上高斯消元即可。
但是如果直接用 long double
做的话,虽然样例过得了,但是精度会被卡掉。
那有什么好办法吗?答案看起来一定是整数,那么我们显然想用 long long
解决。
前面消成上三角的时候,除的东西不能保证整除。
其中一种解决办法是用几个模数进行模意义下的消元,然后 \(CRT\) 合并即可。但是不太好写。
后来问了 zhou888 ,它告诉我一个神奇的做法,每次消去一行的时候,辗转相除,不断除掉共有的最多的那个就行了。
虽然多了个 \(\log n\) 的复杂度,但是确实好写啊。。。
然后复杂度就是 \(O(|S| \times n + |S|^3 \log n)\) 的。
代码
具体实现可以见代码。
#include <bits/stdc++.h>
#define For(i, l, r) for (register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for (register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Rep(i, r) for (register int i = (0), i##end = (int)(r); i < i##end; ++i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl
using namespace std;
typedef long long ll;
template<typename T> inline bool chkmin(T &a, T b) { return b < a ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return b > a ? a = b, 1 : 0; }
inline int read() {
int x(0), sgn(1); char ch(getchar());
for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
return x * sgn;
}
void File() {
#ifdef zjp_shadow
freopen ("3490.in", "r", stdin);
freopen ("3490.out", "w", stdout);
#endif
}
const int N = 14;
ll Mat[N][N];
void Gauss(int n) {
For (i, 1, n) {
For (j, i + 1, n) {
ll a = Mat[i][i], b = Mat[j][i];
while (b) {
ll tmp = a / b; a %= b; swap(a, b); swap(Mat[i], Mat[j]);
For (k, i, n + 1) Mat[j][k] -= tmp * Mat[i][k];
}
}
}
Fordown (i, n, 1) {
For (j, i + 1, n)
Mat[i][n + 1] -= Mat[i][j] * Mat[j][n + 1], Mat[i][j] = 0;
Mat[i][n + 1] /= Mat[i][i]; Mat[i][i] = 1;
}
}
int n, fail[N];
void Get_Fail(char *S) {
For (i, 2, strlen(S + 1)) {
int j = fail[i - 1];
while (j && S[i] != S[j + 1]) j = fail[j];
fail[i] = S[i] == S[j + 1] ? j + 1 : 0;
}
}
char str[N];
int main () {
File();
For (cases, 1, read()) {
int alpha = read(); scanf ("%s", str + 1);
int n = strlen(str + 1);
Get_Fail(str); Set(Mat, 0);
Mat[n + 1][n + 1] = alpha;
For (i, 0, n - 1) {
Mat[i + 1][i + 1] = Mat[i + 1][n + 2] = - alpha;
Rep (j, alpha) {
char cur = j + 'A';
int pos = i;
while (pos && str[pos + 1] != cur) pos = fail[pos];
if (str[pos + 1] == cur) ++ pos;
Mat[i + 1][pos + 1] += 1;
}
}
Gauss(n + 1);
printf ("Case %d:\n", cases);
printf ("%lld\n", Mat[1][n + 2]);
if (cases < casesend) putchar('\n');
}
return 0;
}