【学习笔记】Berlekamp-Massey
orz zhenzhendong
之前贺过一边周指导博客但是弃疗了, 今天又贺了一次.
问题描述
给定一个长度为 \(n\) 的数列 \(\{a_i\}\), 求一个最短的齐次线性递推数列\(\{b_i\}\)(设长度为 \(m\)),使得对于所有 \(m \leq k \leq n\), 有 \(a_k = \sum_{i = 1} ^ m a_{k - i} b_i\)
复杂度要求: \(O(n ^ 2)\)
一看就很适合配合常系数齐次线性递推食用.
算法流程
增量构造.
假设我们当前已经求出了 \(a_{0...i - 1}\) 的线性递推数列. 计算过程中, 我们曾经得出过 \(c\) 个递推式, 第 \(i\) 个递推式在 \(fail_i\) 的位置第一次失效了.
一开始 \(c = 0\), 我们有一个空的递推式.
现在我们加入数 \(a_i\).
设 \(R_c\) 的长度为 \(m\), \(delta_i = a_i - \sum_{k = 1} ^ {m} a_{i - k} R_c(k)\)
如果 \(delta_i = 0\), \(R_c\) 仍是一个合法的递推式.
否则我们要对 \(R_c\) 做出调整, 来得到一个新的符合条件的递推式.
若 \(c = 0\), 那么前 \(i - 1\) 个数都是 \(0\). 我们只需要构造一个包含 \(i\) 个 \(0\) 的递推式即可.
若 \(c \not= 0\), 只需要构造一个递推式 \(R'\), 当 \(|R'| + 1 \leq k < n\) 时, \(\sum_{i = 1} ^ {|R'|} a_{k - i} R'_i = 0\), \(\sum_{i = 1} ^ {|R'| } a_{n - i} R'_i = delta_n\), 那么 \(R_{c + 1} = R_c + R'\) 就符合条件.
我们随便找一个 $ 0 \leq id < c$, 它的前 \(fail_{id} - 1\) 个数都是 \(0\). 如果我们对它作一个位移, 即前面补上 \(i - fail_{id} - 1\) 个 \(0\), 后面跟个 \(1\), 然后接上 \(-R_{id}\), 我们就可以得到一个只有位置 \(i\) 是 \(delta_{fail_{id}}\),其余位置都是 \(0\) 的数组. 然后我们把它整个乘上 \(tmp = \frac{delta_{i}}{delta_{fail_{id}}}\), 就构造出了 \(R'\)
也就是 \(R'\) 为
然后我们还要保证 \(R_c + R'\) 是最短的, 我们找到 \(i - fail_{id} + |R_{id}|\) 最小的即可. (不会严格证明)
模板
数据可以去周指导的博客上看.
#pragma GCC optimize("2,Ofast,inline")
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define LL long long
#define pii pair<int, int>
using namespace std;
const int mod = 1e9 + 7;
template <typename T> T read(T &x) {
int f = 0;
register char c = getchar();
while (c > '9' || c < '0') f |= (c == '-'), c = getchar();
for (x = 0; c >= '0' && c <= '9'; c = getchar())
x = (x << 3) + (x << 1) + (c ^ 48);
if (f) x = -x;
return x;
}
inline void upd(int &x, int y) {
(x += y) >= mod ? x -= mod : 0;
}
inline int add(int x, int y) {
return (x += y) >= mod ? x - mod : x;
}
inline int dec(int x, int y) {
return (x -= y) < 0 ? x + mod : x;
}
inline int Qpow(int x, int p) {
int ans = 1;
for (; p; p >>= 1) {
if (p & 1) ans = 1LL * ans * x % mod;
x = 1LL * x * x % mod;
}
return ans;
}
inline int Inv(int x) {
return Qpow(x, mod - 2);
}
namespace BM {
const int Maxn = 5005;
int n, c;
int a[Maxn], del[Maxn], fail[Maxn];
vector<int> R[Maxn];
vector<int> solve() {
c = 0;
for (int i = 1; i <= n; ++i) {
if (c == 0) {
if (a[i]) {
fail[0] = i;
++c;
del[i] = a[i];
R[c].resize(i);
}
continue;
}
del[i] = a[i];
for (int j = 0; j < R[c].size(); ++j) {
del[i] = dec(del[i], 1LL * R[c][j] * a[i - j - 1] % mod);
}
if (del[i] == 0) continue;
fail[c] = i;
int id = c - 1, v = i - fail[id] + R[id].size();
for (int j = c - 1; j >= 0; --j) {
if (v > i - fail[j] + R[j].size()) {
v = i - fail[j] + R[j].size();
id = j;
}
}
int p = i - fail[id];
int tmp = 1LL * del[i] * Inv(del[fail[id]]) % mod;
R[c + 1] = R[c];
if (R[c + 1].size() < v) R[c + 1].resize(v);
upd(R[c + 1][p - 1], tmp);
for (int j = 0; j < R[id].size(); ++j) {
upd(R[c + 1][p + j], -1LL * tmp * R[id][j] % mod + mod);
}
++c;
}
if (c == 0) return vector<int>(0);
return R[c];
}
}
using namespace BM;
int main() {
read(n);
for (int i = 1; i <= n; ++i) read(a[i]);
vector<int> ans = BM::solve();
cout << ans.size() << endl;
for (int i = 0; i < ans.size(); ++i)
cout << ans[i] << ' ';
puts("");
}