[洛谷P4723]【模板】线性递推
题目大意:给定 $k$ 阶齐次线性递推数列 $a_i$ 的系数 $f_1,\dots,f_k$ 与前 $k$ 项 $a_0,\dots,a_{k-1}$,求第 $n$ 项 $a_n \bmod 998244353$。
即:$a_n=\sum\limits_{i=1}^{k}f_i \times a_{n-i}$
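题解:常系数齐次线性递推。令 $G(x)=x^k-\sum\limits_{i=1}^{k}f_i x^{k-i}$,用快速幂 + 多项式取模求出 $x^n \bmod G(x)=\sum\limits_{i=0}^{k-1}c_i x^i$,则 $a_n=\sum\limits_{i=0}^{k-1}c_i a_i$。详细推导先见洛谷题解,下回再补。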
卡点:数组大小计算错误,求逆中途计算时忘记加$mod$等
C++ Code:(这份全部是板子,可以用来测试,但是常数巨大)
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>

// Capacity limits: maxk bounds the recurrence order k (f is indexed 1..k,
// a is indexed 0..k-1); maxn bounds every NTT buffer (DIV_MOD transforms a
// length of at least 4k - 2).
constexpr int maxk = 32010;
constexpr int maxn = 131072;
const int mod = 998244353;

// (x * y) % mod evaluated in 64-bit arithmetic so the product cannot overflow.
// Replaces an unparenthesized macro; the result stays long long, exactly as the
// macro's expression did, so expressions such as (2 - mul(a, b) + mod) * c
// remain 64-bit.
inline long long mul(long long x, long long y) { return x * y % mod; }

namespace Math {
// base^p % mod by binary exponentiation (p >= 0).
inline int pw(int base, int p) {
    int res = 1;
    for (; p; p >>= 1, base = mul(base, base))
        if (p & 1) res = mul(res, base);
    return res;
}
// Modular inverse via Fermat's little theorem (mod is prime); x must be non-zero mod `mod`.
inline int inv(int x) { return pw(x, mod - 2); }
}  // namespace Math

// Adds mod to x iff x is negative: maps [-mod, mod) onto [0, mod).
inline void reduce(int &x) { x += x >> 31 & mod; }

// Fully reduces an arbitrary int into [0, mod). Unlike reduce(), this is
// correct for raw input values whose magnitude is >= mod.
inline int normalize(int x) {
    x %= mod;
    return x < 0 ? x + mod : x;
}

namespace Poly {
constexpr int N = maxn;
int lim, s, rev[N], Wn[N];

// Prepares an NTT of length lim = smallest power of two >= n: the bit-reversal
// permutation rev[] and the powers Wn[j] = w^j of w = 3^((mod - 1) / lim).
inline void init(const int n) {
    lim = 1, s = -1;
    while (lim < n) lim <<= 1, ++s;
    for (int i = 1; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
    const int t = Math::pw(3, (mod - 1) / lim);
    Wn[0] = 1;
    for (int i = 1; i < lim; ++i) Wn[i] = mul(Wn[i - 1], t);
}

// In-place iterative NTT on A[0..lim). op != 0: forward transform.
// op == 0: inverse transform (forward pass, scale by 1/lim, reverse A[1..lim)).
// Entries of A are expected in [0, mod).
inline void FFT(int *A, const int op = 1) {
    for (int i = 1; i < lim; ++i)
        if (i < rev[i]) std::swap(A[i], A[rev[i]]);
    for (int mid = 1; mid < lim; mid <<= 1) {
        const int t = lim / mid >> 1;  // stride into Wn for this butterfly layer
        for (int i = 0; i < lim; i += mid << 1)
            for (int j = 0; j < mid; ++j) {
                const int X = A[i + j], Y = mul(A[i + j + mid], Wn[t * j]);
                reduce(A[i + j] += Y - mod), reduce(A[i + j + mid] = X - Y);
            }
    }
    if (!op) {
        const int ilim = Math::inv(lim);
        for (int i = 0; i < lim; ++i) A[i] = mul(A[i], ilim);
        std::reverse(A + 1, A + lim);
    }
}

// B[0..n) = A^{-1} mod x^n by Newton iteration B <- B * (2 - A * B).
// Requires A[0] != 0 and n >= 1.
void INV(int *A, int *B, int n) {
    if (n == 1) {
        *B = Math::inv(*A);
        return;
    }
    static int C[N], D[N];
    const int len = (n + 1) >> 1;
    // A transform length >= 3 * len suffices: cyclic wrap-around of B*B*A only
    // lands below index len, and those entries are not copied back.
    INV(A, B, len), init(len * 3);
    std::memcpy(C, A, n << 2), std::memset(C + n, 0, (lim - n) << 2);
    std::memcpy(D, B, len << 2), std::memset(D + len, 0, (lim - len) << 2);
    FFT(C), FFT(D);
    for (int i = 0; i < lim; ++i) D[i] = (2 - mul(D[i], C[i]) + mod) * D[i] % mod;
    FFT(D, 0);
    std::memcpy(B + len, D + len, (n - len) << 2);
}

// Quotient of A (n coefficients) by B (m coefficients) via
// rev(Q) = rev(A) * rev(B)^{-1} mod x^{n-m+1}. Q must have room for lim ints.
void DIV(int *A, int *B, int *Q, int n, int m) {
    static int C[N], D[N], E[N];
    const int len = n - m + 1;
    std::reverse_copy(A, A + n, C), std::reverse_copy(B, B + m, D);
    INV(D, E, len), init(len << 1);
    std::memset(C + len, 0, (lim - len) << 2), std::memset(E + len, 0, (lim - len) << 2);
    FFT(C), FFT(E);
    for (int i = 0; i < lim; ++i) Q[i] = mul(C[i], E[i]);
    FFT(Q, 0), std::reverse(Q, Q + len);
}

// Q = A / B and R = A - B * Q (R[0..m-1) is the remainder), for A with n
// coefficients and B with m coefficients. R must have room for lim ints.
void DIV_MOD(int *A, int *B, int *Q, int *R, int n, int m) {
    static int C[N], D[N], E[N];
    const int len = n - m + 1;
    DIV(A, B, Q, n, m), init(n << 1);
    std::memcpy(C, A, n << 2), std::memset(C + n, 0, (lim - n) << 2);
    std::memcpy(D, B, m << 2), std::memset(D + m, 0, (lim - m) << 2);
    std::memcpy(E, Q, len << 2), std::memset(E + len, 0, (lim - len) << 2);
    FFT(C), FFT(D), FFT(E);
    for (int i = 0; i < lim; ++i) reduce(R[i] = C[i] - mul(D[i], E[i]));
    FFT(R, 0);
}

// A <- A mod B, where A has 2m - 1 coefficients and B has m + 1 (degree m).
// Only A[0..m) is overwritten.
void MOD(int *A, int *B, int m) {
    static int Q[N], R[N];
    DIV_MOD(A, B, Q, R, (m << 1) - 1, m + 1);
    std::memcpy(A, R, m << 2);
}

// base <- base^p mod Mod, where Mod has degree m and base has fewer than m
// coefficients. Uses function-local static buffers, so it is meant for a single call.
void POW(int *base, int p, int *Mod, int m) {
    static int res[N], T[N];
    res[0] = 1;
    while (p) {
        if (p & 1) {
            init(m << 1), std::memset(res + m, 0, (lim - m) << 2);
            std::memcpy(T, base, m << 2), std::memset(T + m, 0, (lim - m) << 2);
            FFT(T), FFT(res);
            for (int i = 0; i < lim; ++i) res[i] = mul(res[i], T[i]);
            FFT(res, 0);
            MOD(res, Mod, m);
        }
        p >>= 1;
        if (p) {
            init(m << 1), std::memset(base + m, 0, (lim - m) << 2);
            FFT(base);
            for (int i = 0; i < lim; ++i) base[i] = mul(base[i], base[i]);
            FFT(base, 0), MOD(base, Mod, m);
        }
    }
    std::memcpy(base, res, m << 2);
}

// Returns a_n % mod for a_n = sum_{i=1..k} f_i * a_{n-i}.
// a holds a_0..a_{k-1}; f holds f_1..f_k (f[0] unused). Raw values of any
// sign or magnitude are accepted; they are normalized into [0, mod) here.
int solve(int *f, int *a, int n, int k) {
    // For k == 1 the polynomial machinery would truncate x to zero and recurse
    // forever in INV(..., 0); the closed form is a_n = a_0 * f_1^n.
    if (k == 1) return static_cast<int>(mul(normalize(a[0]), Math::pw(normalize(f[1]), n)));
    static int A[maxn], G[maxn];
    // G(x) = x^k - sum_{i=1..k} f_i x^{k-i}, with coefficients in [0, mod).
    for (int i = 1; i <= k; ++i) {
        const int c = normalize(f[i]);
        G[k - i] = c ? mod - c : 0;
    }
    G[k] = A[1] = 1;  // A starts as the polynomial x
    POW(A, n, G, k);  // A = x^n mod G
    int ans = 0;
    for (int i = 0; i < k; ++i) reduce(ans += mul(A[i], normalize(a[i])) - mod);
    return ans;
}
}  // namespace Poly

int n, k;
int f[maxk], a[maxk];

int main() {
    std::ios::sync_with_stdio(false), std::cin.tie(nullptr);
    std::cin >> n >> k;
    for (int i = 1; i <= k; ++i) std::cin >> f[i];
    for (int i = 0; i < k; ++i) std::cin >> a[i];
    std::cout << Poly::solve(f, a, n, k) << '\n';
    return 0;
}
注意到取模所用的多项式 $G(x)$ 是固定的,因此可以预处理出其反转多项式的逆元,以及 $G$ 与该逆元的点值表示(NTT 结果),从而减小常数。
C++ Code:(这一份常数还算正常)
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>

// Capacity limits: maxk bounds the recurrence order k (f is indexed 1..k,
// a is indexed 0..k-1); maxn bounds every NTT buffer (solve() transforms a
// length of at least 2k).
constexpr int maxk = 32010;
constexpr int maxn = 65536;
const int mod = 998244353;

// (x * y) % mod evaluated in 64-bit arithmetic so the product cannot overflow.
// Replaces an unparenthesized macro; the result stays long long, exactly as the
// macro's expression did.
inline long long mul(long long x, long long y) { return x * y % mod; }

namespace Math {
// base^p % mod by binary exponentiation (p >= 0).
inline int pw(int base, int p) {
    int res = 1;
    for (; p; p >>= 1, base = mul(base, base))
        if (p & 1) res = mul(res, base);
    return res;
}
// Modular inverse via Fermat's little theorem (mod is prime); x must be non-zero mod `mod`.
inline int inv(int x) { return pw(x, mod - 2); }
}  // namespace Math

// Adds mod to x iff x is negative: maps [-mod, mod) onto [0, mod).
inline void reduce(int &x) { x += x >> 31 & mod; }

// Fully reduces an arbitrary int into [0, mod). Unlike reduce(), this is
// correct for raw input values whose magnitude is >= mod.
inline int normalize(int x) {
    x %= mod;
    return x < 0 ? x + mod : x;
}

namespace Poly {
constexpr int N = maxn;
int lim, s, rev[N], Wn[N];

// Prepares an NTT of length lim = smallest power of two >= n: the bit-reversal
// permutation rev[] and the powers Wn[j] = w^j of w = 3^((mod - 1) / lim).
inline void init(const int n) {
    lim = 1, s = -1;
    while (lim < n) lim <<= 1, ++s;
    for (int i = 1; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
    const int t = Math::pw(3, (mod - 1) / lim);
    Wn[0] = 1;
    for (int i = 1; i < lim; ++i) Wn[i] = mul(Wn[i - 1], t);
}

// In-place iterative NTT on A[0..lim). op != 0: forward transform.
// op == 0: inverse transform (forward pass, scale by 1/lim, reverse A[1..lim)).
// Entries of A are expected in [0, mod).
inline void FFT(int *A, const int op = 1) {
    for (int i = 1; i < lim; ++i)
        if (i < rev[i]) std::swap(A[i], A[rev[i]]);
    for (int mid = 1; mid < lim; mid <<= 1) {
        const int t = lim / mid >> 1;  // stride into Wn for this butterfly layer
        for (int i = 0; i < lim; i += mid << 1)
            for (int j = 0; j < mid; ++j) {
                const int X = A[i + j], Y = mul(A[i + j + mid], Wn[t * j]);
                reduce(A[i + j] += Y - mod), reduce(A[i + j + mid] = X - Y);
            }
    }
    if (!op) {
        const int ilim = Math::inv(lim);
        for (int i = 0; i < lim; ++i) A[i] = mul(A[i], ilim);
        std::reverse(A + 1, A + lim);
    }
}

// B[0..n) = A^{-1} mod x^n by Newton iteration B <- B * (2 - A * B).
// Requires A[0] != 0 and n >= 1.
void INV(int *A, int *B, int n) {
    if (n == 1) {
        *B = Math::inv(*A);
        return;
    }
    static int C[N], D[N];
    const int len = (n + 1) >> 1;
    // A transform length >= 3 * len suffices: cyclic wrap-around of B*B*A only
    // lands below index len, and those entries are not copied back.
    INV(A, B, len), init(len * 3);
    std::memcpy(C, A, n << 2), std::memset(C + n, 0, (lim - n) << 2);
    std::memcpy(D, B, len << 2), std::memset(D + len, 0, (lim - len) << 2);
    FFT(C), FFT(D);
    for (int i = 0; i < lim; ++i) D[i] = (2 - mul(D[i], C[i]) + mod) * D[i] % mod;
    FFT(D, 0);
    std::memcpy(B + len, D + len, (n - len) << 2);
}

// The fixed modulus polynomial G (degree k) and INVG = rev(G)^{-1} mod x^k.
// solve() leaves both in NTT (point-value) form of length lim, so the
// functions below must run with that same lim.
int G[N], INVG[N];

// Quotient of A (n coefficients) by G (m = k + 1 coefficients) via
// rev(Q) = rev(A) * INVG mod x^{n-m+1}. Q must have room for lim ints.
void DIV(int *A, int *Q, int n, int m) {
    static int C[N];
    const int len = n - m + 1;
    std::reverse_copy(A, A + n, C), std::memset(C + len, 0, (lim - len) << 2);
    FFT(C);
    for (int i = 0; i < lim; ++i) Q[i] = mul(C[i], INVG[i]);
    FFT(Q, 0), std::reverse(Q, Q + len);
}

// R[0..m) = A - G * (A / G) for A with n coefficients; R[0..m-1) is A mod G.
// R must have room for lim ints and must not alias A.
void DIV_MOD(int *A, int *R, int n, int m) {
    static int Q[N];
    const int len = n - m + 1;
    DIV(A, Q, n, m), std::memset(Q + len, 0, (lim - len) << 2);
    FFT(Q);
    for (int i = 0; i < lim; ++i) R[i] = mul(G[i], Q[i]);
    FFT(R, 0);
    for (int i = 0; i < m; ++i) reduce(R[i] = A[i] - R[i]);
}

// A <- x^p * A mod G, for A with m = k coefficients (called with A = 1, so it yields x^p mod G),
// by recursive squaring; an odd bit multiplies by x.
void POW(int *A, int p, int m) {
    if (!p) return;
    POW(A, p >> 1, m);
    static int T[N];
    std::memcpy(T, A, m << 2), std::memset(T + m, 0, (lim - m) << 2);
    FFT(T);
    for (int i = 0; i < lim; ++i) T[i] = mul(T[i], T[i]);
    FFT(T, 0);
    if (p & 1) {
        // Multiply by x: shift every coefficient up one slot. The loop stops at
        // i == 1 so T[-1] is never read (the old `~i` condition read it).
        for (int i = 2 * m - 1; i > 0; --i) T[i] = T[i - 1];
        T[0] = 0;
    }
    DIV_MOD(T, A, 2 * m, m + 1);
}

// Returns a_n % mod for a_n = sum_{i=1..k} f_i * a_{n-i}.
// a holds a_0..a_{k-1}; f holds f_1..f_k (f[0] unused). Raw values of any
// sign or magnitude are accepted; they are normalized into [0, mod) here.
int solve(int *f, int *a, int n, int k) {
    static int A[maxn], B[maxn];
    // G(x) = x^k - sum_{i=1..k} f_i x^{k-i}, with coefficients in [0, mod).
    for (int i = 1; i <= k; ++i) {
        const int c = normalize(f[i]);
        G[k - i] = c ? mod - c : 0;
    }
    G[k] = A[0] = 1;
    // B = rev(G) truncated to k coefficients; INVG = B^{-1} mod x^k.
    std::reverse_copy(G, G + k + 1, B), B[k] = 0;
    INV(B, INVG, k), init(k << 1);
    FFT(G), FFT(INVG);  // precompute point values once; POW/DIV/DIV_MOD reuse them
    POW(A, n, k);       // A = x^n mod G
    int ans = 0;
    for (int i = 0; i < k; ++i) reduce(ans += mul(A[i], normalize(a[i])) - mod);
    return ans;
}
}  // namespace Poly

int n, k;
int f[maxk], a[maxk];

int main() {
    std::ios::sync_with_stdio(false), std::cin.tie(nullptr);
    std::cin >> n >> k;
    for (int i = 1; i <= k; ++i) std::cin >> f[i];
    for (int i = 0; i < k; ++i) std::cin >> a[i];
    std::cout << Poly::solve(f, a, n, k) << '\n';
    return 0;
}