【学习笔记】Berlekamp-Massey

orz zhenzhendong

之前贺过一边周指导博客但是弃疗了, 今天又贺了一次.

问题描述

给定一个长度为 \(n\) 的数列 \(\{a_i\}\), 求一个最短的齐次线性递推数列\(\{b_i\}\)(设长度为 \(m\)),使得对于所有 \(m \leq k \leq n\), 有 \(a_k = \sum_{i = 1} ^ m a_{k - i} b_i\)

复杂度要求: \(O(n ^ 2)\)

一看就很适合配合常系数齐次线性递推食用.

算法流程

增量构造.

假设我们当前已经求出了 \(a_{0...i - 1}\) 的线性递推数列. 计算过程中, 我们曾经得出过 \(c\) 个递推式, 第 \(i\) 个递推式在 \(fail_i\) 的位置第一次失效了.

一开始 \(c = 0\), 我们有一个空的递推式.

现在我们加入数 \(a_i\).

\(R_c\) 的长度为 \(m\), \(delta_i = a_i - \sum_{k = 1} ^ {m} a_{i - k} R_c(k)\)

如果 \(delta_i = 0\), \(R_c\) 仍是一个合法的递推式.

否则我们要对 \(R_c\) 做出调整, 来得到一个新的符合条件的递推式.

\(c = 0\), 那么前 \(i - 1\) 个数都是 \(0\). 我们只需要构造一个包含 \(i\)\(0\) 的递推式即可.

\(c \not= 0\), 只需要构造一个递推式 \(R'\), 当 \(|R'| + 1 \leq k < n\) 时, \(\sum_{i = 1} ^ {|R'|} a_{k - i} R'_i = 0\), \(\sum_{i = 1} ^ {|R'| } a_{n - i} R'_i = delta_n\), 那么 \(R_{c + 1} = R_c + R'\) 就符合条件.

我们随便找一个 $ 0 \leq id < c$, 它的前 \(fail_{id} - 1\) 个数都是 \(0\). 如果我们对它作一个位移, 即前面补上 \(i - fail_{id} - 1\)\(0\), 后面跟个 \(1\), 然后接上 \(-R_{id}\), 我们就可以得到一个只有位置 \(i\)\(delta_{fail_{id}}\),其余位置都是 \(0\) 的数组. 然后我们把它整个乘上 \(tmp = \frac{delta_{i}}{delta_{fail_{id}}}\), 就构造出了 \(R'\)

也就是 \(R'\)

\[\{0,0,...0,tmp,-tmp R_{id}(1),-tmp R_{id}(2),...\} \]

然后我们还要保证 \(R_c + R'\) 是最短的, 我们找到 \(i - fail_{id} + |R_{id}|\) 最小的即可. (不会严格证明)

模板

数据可以去周指导的博客上看.

#pragma GCC optimize("2,Ofast,inline")
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define LL long long
#define pii pair<int, int>
using namespace std;
const int mod = 1e9 + 7;

template <typename T> T read(T &x) {
	int f = 0;
	register char c = getchar();
	while (c > '9' || c < '0') f |= (c == '-'), c = getchar();
	for (x = 0; c >= '0' && c <= '9'; c = getchar())
		x = (x << 3) + (x << 1) + (c ^ 48);
	if (f) x = -x;
	return x;
}

inline void upd(int &x, int y) {
	(x += y) >= mod ? x -= mod : 0;
}

inline int add(int x, int y) {
	return (x += y) >= mod ? x - mod : x;
}

inline int dec(int x, int y) {
	return (x -= y) < 0 ? x + mod : x;
}

inline int Qpow(int x, int p) {
	int ans = 1;
	for (; p; p >>= 1) {
		if (p & 1) ans = 1LL * ans * x % mod;
		x = 1LL * x * x % mod;
	}
	return ans;
}

inline int Inv(int x) {
	return Qpow(x, mod - 2);
}

namespace BM {

	const int Maxn = 5005;
	
	int n, c;
	int a[Maxn], del[Maxn], fail[Maxn];
	vector<int> R[Maxn];

	vector<int> solve() {
		c = 0;
		for (int i = 1; i <= n; ++i) {
			if (c == 0) {
				if (a[i]) {
					fail[0] = i;
					++c;
					del[i] = a[i];
					R[c].resize(i);
				}
				continue;
			}
			del[i] = a[i];
			for (int j = 0; j < R[c].size(); ++j) {
				del[i] = dec(del[i], 1LL * R[c][j] * a[i - j - 1] % mod);
			}
			if (del[i] == 0) continue;
			fail[c] = i;
			int id = c - 1, v = i - fail[id] + R[id].size();
			for (int j = c - 1; j >= 0; --j) {
				if (v > i - fail[j] + R[j].size()) {
					v = i - fail[j] + R[j].size();
					id = j;
				}
			}
			int p = i - fail[id];
			int tmp = 1LL * del[i] * Inv(del[fail[id]]) % mod;
			R[c + 1] = R[c];
			if (R[c + 1].size() < v) R[c + 1].resize(v);
			upd(R[c + 1][p - 1], tmp);
			for (int j = 0; j < R[id].size(); ++j) {
				upd(R[c + 1][p + j], -1LL * tmp * R[id][j] % mod + mod);
			}
			++c;
		}
		if (c == 0) return vector<int>(0);
		return R[c];
	}
}
using namespace BM;

int main() {
	read(n);
	for (int i = 1; i <= n; ++i) read(a[i]);
	vector<int> ans = BM::solve();
	cout << ans.size() << endl;
	for (int i = 0; i < ans.size(); ++i)
		cout << ans[i] << ' ';
	puts("");
}
posted @ 2019-11-12 21:21  Vexoben  阅读(131)  评论(1编辑  收藏  举报