题解-P4723 【模板】常系数齐次线性递推 [*hard]
正着推感觉很不可做,考虑倒着推。
例如我们要求 \(a_i = a_{i - 1} + 2 a_{i - 2}, a_0 = 2, a_1 = 1\) 的第 \(4\) 项的时候,我们倒着推 :
\[a_4 = a_3 + 2a_2 = 3a_2 + 2a_1 = 5a_1 + 6a_0 = 5 \times 1 + 6 \times 2 = 17
\]
我们发现,我们倒着推的过程相当于每次把最高的那一项 \(a_x\) 变成 \(\sum\limits_{i = 1}^{k} f_i a_{x - i}\) 。
这很像一个取模的过程。如果一个数可以表示成 \(a\) 数组中若干项的线性组合 \(\sum\limits_{i \ge 0} p_i a_i\),那么我们把它对应到多项式 \(\sum\limits_{i \ge 0} p_i x^i\) 。
一次操作相当于是把 \(x^t\) 变成 \(\sum\limits_{i = 1}^{k} f_i x^{t - i}\) 。
构造多项式 \(\lambda\),使得对所有 \(t \ge k\) 都有 \(x^t \equiv \sum\limits_{i = 1}^{k} f_i x^{t - i} \pmod \lambda\) ,也就是:
\[x^t - \sum\limits_{i = 1}^{k} f_i x^{t - i} \equiv 0 \pmod \lambda
\]
\[x^{t - k} ( x^k - \sum\limits_{i = 1}^{k} f_i x^{k - i} ) \equiv 0 \pmod \lambda
\]
让 \(\lambda = x^k - \sum\limits_{i = 1}^{k} f_i x^{k - i}\) 即可满足。
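因此 \(x^n\) 可以当作 \(x^n \bmod \lambda\) 来计算:设 \(x^n \bmod \lambda = \sum\limits_{i = 0}^{k - 1} r_i x^i\)(\(\lambda\) 是 \(k\) 次多项式,所以余式次数小于 \(k\)),则答案为 \(a_n = \sum\limits_{i = 0}^{k - 1} r_i a_i\) 。\(x^n \bmod \lambda\) 用快速幂配合多项式乘法与取模求出,复杂度 \(O(k \log k \log n)\) 。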
代码:
#include<bits/stdc++.h>
#define L(i, j, k) for(int i = j, i##E = k; i <= i##E; i++)
#define R(i, j, k) for(int i = j, i##E = k; i >= i##E; i--)
#define ll long long
#define ull unsigned long long
#define db double
#define pii pair<int, int>
#define mkp make_pair
using namespace std;
// Reads one decimal integer from stdin.
// Any non-digit characters before the number are skipped. A '-' seen among
// them makes the result negative.
// Returns 0 if end-of-file is reached before any digit. Without that check
// the skip loop would spin forever on truncated input.
// `ch` must be an int, not a char: getchar() returns EOF as an int value
// distinct from every byte, and isdigit() is only defined for EOF and
// unsigned-char values.
inline int read() {
    int x = 0, f = 1;
    int ch = getchar();
    while (!isdigit(ch)) {
        if (ch == EOF) return 0;
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (isdigit(ch)) {
        x = x * 10 + (ch ^ 48);  // ch ^ 48 == ch - '0' for '0'..'9'
        ch = getchar();
    }
    return x * f;
}
// N   : capacity of every coefficient / scratch array in this file.
// mod : NTT-friendly prime 998244353 = 119 * 2^23 + 1.
// G   : primitive root 3 used to build the NTT roots of unity.
// iG  : G^{-1} mod `mod`; (mod + 1) is divisible by 3, and 3 * iG = mod + 1.
constexpr int N = 1 << 18;
constexpr int mod = 998244353;
constexpr int G = 3;
constexpr int iG = (mod + 1) / G;
int qpow(int x, int y = mod - 2) {
int res = 1;
for(; y; x = (ll) x * x % mod, y >>= 1) if(y & 1) res = (ll) res * x % mod;
return res;
}
// Shared NTT state used by every transform below.
int Lim;      // length the twiddle tables were built for (set once by init)
int lim;      // current transform length, set by up(); NTT needs lim <= Lim
int pp[N];    // bit-reversal permutation for the current lim (filled by revlim)
int PowG[N];  // PowG[i]  = w^i with w = G^((mod - 1) / Lim)
int iPowG[N]; // iPowG[i] = w^-i, built from iG the same way
void revlim() { L(i, 0, lim - 1) pp[i] = ((pp[i >> 1] >> 1) | ((i & 1) * (lim >> 1))); }
// Sets the global lim to the smallest power of two strictly greater than x.
void up(int x) {
    lim = 1;
    while (lim <= x) lim <<= 1;
}
void cle(int *f) { L(i, 0, lim - 1) f[i] = 0; }
void init(int x) {
int Pw;
up(x), Lim = lim;
Pw = qpow(G, (mod - 1) / Lim), PowG[0] = 1;
L(i, 1, lim - 1) PowG[i] = (ll) PowG[i - 1] * Pw % mod;
Pw = qpow(iG, (mod - 1) / Lim), iPowG[0] = 1;
L(i, 1, lim - 1) iPowG[i] = (ll) iPowG[i - 1] * Pw % mod;
}
// Lifts a negative x by one multiple of mod (maps (-mod, mod) onto [0, mod)).
// Non-negative values are left unchanged.
inline void fmod(int &x) {
    if (x < 0) x += mod;
}
// In-place modular addition: x = (x + y) mod `mod`, for x, y in [0, mod).
inline void ad(int &x, int y) {
    const int s = x + y;
    x = (s >= mod) ? s - mod : s;
}
// Returns (x + y) reduced by at most one mod.
// NTT calls it with x in [0, mod) and y in [0, mod].
inline int Sum(int x, int y) {
    const int s = x + y;
    return (s < mod) ? s : s - mod;
}
void NTT(int *f, int flag) {
L(i, 0, lim - 1) if(pp[i] < i) swap(f[pp[i]], f[i]);
for(int i = 2; i <= lim; i <<= 1)
for(int j = 0, l = (i >> 1), ch = Lim / i; j < lim; j += i)
for(int k = j, now = 0; k < j + l; k ++) {
int pa = f[k], pb = (ll) f[k + l] * (flag == 1 ? PowG[now] : iPowG[now]) % mod;
f[k] = Sum(pa, pb), f[k + l] = Sum(pa, mod - pb), now += ch;
}
if(flag == -1) {
int nylim = qpow(lim);
L(i, 0, lim - 1) f[i] = (ll) f[i] * nylim % mod;
}
}
int sav[N];
void inv(int *f, int *g, int len) {
if(len == 1) return g[0] = qpow(f[0]), void();
inv(f, g, (len + 1) >> 1), up(len << 1), cle(sav), copy(f, f + len, sav), revlim(), NTT(sav, 1), NTT(g, 1);
L(i, 0, lim - 1) g[i] = (ll) g[i] * (2ll + mod - (ll) g[i] * sav[i] % mod) % mod;
NTT(g, -1), fill(g + len, g + lim, 0);
}
void Mul(int *f, int *g, int *ans, int n, int m) {
static int A[N], B[N];
up(n + m), revlim(), cle(A), cle(B), copy(f, f + n, A), copy(g, g + m, B);
NTT(A, 1), NTT(B, 1);
L(i, 0, lim - 1) A[i] = (ll) A[i] * B[i] % mod;
NTT(A, -1), copy(A, A + n + m - 1, ans);
}
// Polynomial division with remainder mod `mod`: f = q * g + r.
// f has n coefficients and g has m coefficients, lowest degree first.
// Outputs:
//   ansa: the n - m + 1 coefficients of q (nothing is written when n < m).
//   ansb: the m - 1 coefficients of r.
// Method: rev(q) = rev(f) * rev(g)^{-1} mod x^{n-m+1}.
// f and g are reversed in place while q is computed and reversed back before
// returning. The leading coefficient g[m - 1] must be invertible mod `mod`.
void div(int *f, int *g, int *ansa, int *ansb, int n, int m) {
    static int A[N];
    if (n < m) {
        // deg f < deg g: the quotient is empty and the remainder is f itself.
        // Without this guard inv() would be called with len <= 0 and never
        // reach its base case.
        L(i, 0, m - 2) ansb[i] = (i < n) ? f[i] : 0;
        return;
    }
    const int qlen = n - m + 1;
    reverse(f, f + n);
    reverse(g, g + m);
    up(qlen << 1);
    cle(A);                    // inv() relies on A being zero-filled
    inv(g, A, qlen);           // A = rev(g)^{-1} mod x^qlen
    Mul(f, A, A, qlen, qlen);  // A = rev(f) * rev(g)^{-1}; first qlen terms = rev(q)
    reverse(A, A + qlen);
    copy(A, A + qlen, ansa);
    reverse(f, f + n);
    reverse(g, g + m);
    Mul(A, g, A, qlen, m);     // A = q * g
    L(i, 0, m - 2) ansb[i] = (f[i] - A[i] + mod) % mod;
}
int n, m, f[N], a[N], res[N], g[N], sv[N], vs[N], ans;
int main() {
n = read(), m = read(), init(m << 1);
f[m] = 1;
R(i, m - 1, 0) f[i] = ( mod - read() % mod ) % mod;
L(i, 0, m - 1) a[i] = read() % mod, fmod(a[i] += mod);
res[0] = 1, g[1] = 1;
for(; n; Mul(g, g, g, m, m), div(g, f, sv, vs, m << 1, m + 1), copy(vs, vs + m, g), n >>= 1)
if(n & 1) Mul(res, g, res, m, m), div(res, f, sv, vs, m << 1, m + 1), copy(vs, vs + m, res);
L(i, 0, m - 1) ad(ans, (ll) res[i] * a[i] % mod);
cout << ans << endl;
return 0;
}