[洛谷P4723]【模板】线性递推

题目大意:给定一个满足$k$阶齐次线性递推的数列$a_i$的前$k$项$a_0,\dots,a_{k-1}$以及系数$f_1,\dots,f_k$,求它的第$n$项(对$998244353$取模)。

即对$n\ge k$有:$a_n=\sum\limits_{i=1}^{k}f_i \times a_{n-i}$

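解:线性齐次递推。记特征多项式$G(x)=x^k-\sum\limits_{i=1}^{k}f_ix^{k-i}$,用快速幂配合多项式取模求出$R(x)=x^n \bmod G(x)=\sum\limits_{i=0}^{k-1}r_ix^i$,则$a_n=\sum\limits_{i=0}^{k-1}r_ia_i$。详细推导先见洛谷题解,下回再补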

卡点:数组大小计算错误,求逆中途计算时忘记加$mod$等

 

C++ Code:(这份全部是板子,可以用来测试,但是常数巨大)

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#define maxk 32010   // capacity of the input arrays f[1..k] and a[0..k)
#define maxn 131072  // capacity of every polynomial / NTT buffer (all length-lim transforms must fit)
const int mod = 998244353;

// Modular product of x and y, yielding a long long (in [0, mod) for
// non-negative operands). The whole expansion is parenthesized so the macro
// acts as a single operand inside larger expressions such as `2 * mul(a, b)`.
#define mul(x, y) (static_cast<long long> (x) * (y) % mod)

namespace Math {
	// Fast modular exponentiation: returns base^p mod `mod` for p >= 0.
	// The accumulator is an ordinary local, so pw is reentrant.
	inline int pw(int base, int p) {
		int res = 1;
		for (; p; p >>= 1, base = mul(base, base))
			if (p & 1) res = mul(res, base);
		return res;
	}
	// Modular inverse by Fermat's little theorem (relies on mod being prime);
	// x must be non-zero modulo mod.
	inline int inv(int x) { return pw(x, mod - 2); }
}
// Folds x from (-mod, mod) into [0, mod) by adding mod when x is negative.
// Uses a comparison instead of `x >> 31 & mod`, which depended on an
// arithmetic right shift of a negative int (implementation-defined before
// C++20) and on int being exactly 32 bits wide.
inline void reduce(int &x) { x += x < 0 ? mod : 0; }

namespace Poly {
#define N maxn
	int lim, s, rev[N], Wn[N];
	inline void init(const int n) {
		lim = 1, s = -1; while (lim < n) lim <<= 1, ++s;
		for (register int i = 1; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
		const int t = Math::pw(3, (mod - 1) / lim);
		*Wn = 1; for (register int *i = Wn + 1; i != Wn + lim; ++i) *i = mul(*(i - 1), t);
	}
	inline void FFT(int *A, const int op = 1) {
		for (register int i = 1; i < lim; ++i) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
		for (register int mid = 1; mid < lim; mid <<= 1) {
			const int t = lim / mid >> 1;
			for (register int i = 0; i < lim; i += mid << 1)
				for (register int j = 0; j < mid; ++j) {
					const int X = A[i + j], Y = mul(A[i + j + mid], Wn[t * j]);
					reduce(A[i + j] += Y - mod), reduce(A[i + j + mid] = X - Y);
				}
		}
		if (!op) {
			const int ilim = Math::inv(lim);
			for (register int *i = A; i != A + lim; ++i) *i = mul(*i, ilim);
			std::reverse(A + 1, A + lim);
		}
	}

	void INV(int *A, int *B, int n) {
		if (n == 1) { *B = Math::inv(*A); return ; }
		static int C[N], D[N];
		const int len = n + 1 >> 1;
		INV(A, B, len), init(len * 3);
		std::memcpy(C, A, n << 2), std::memset(C + n, 0, lim - n << 2);
		std::memcpy(D, B, len << 2), std::memset(D + len, 0, lim - len << 2);
		FFT(C), FFT(D);
		for (int i = 0; i < lim; ++i) D[i] = (2 - mul(D[i], C[i]) + mod) * D[i] % mod;
		FFT(D, 0);
		std::memcpy(B + len, D + len, n - len << 2);
	}
	void DIV(int *A, int *B, int *Q, int n, int m) {
		static int C[N], D[N], E[N];
		const int len = n - m + 1;
		std::reverse_copy(A, A + n, C), std::reverse_copy(B, B + m, D);
		INV(D, E, len), init(len << 1);
		std::memset(C + len, 0, lim - len << 2), std::memset(E + len, 0, lim - len << 2);
		FFT(C), FFT(E);
		for (int i = 0; i < lim; ++i) Q[i] = mul(C[i], E[i]);
		FFT(Q, 0), std::reverse(Q, Q + len);
	}
	void DIV_MOD(int *A, int *B, int *Q, int *R, int n, int m) {
		static int C[N], D[N], E[N];
		const int len = n - m + 1;
		DIV(A, B, Q, n, m), init(n << 1);
		std::memcpy(C, A, n << 2), std::memset(C + n, 0, lim - n << 2);
		std::memcpy(D, B, m << 2), std::memset(D + m, 0, lim - m << 2);
		std::memcpy(E, Q, len << 2), std::memset(E + len, 0, lim - len << 2);
		FFT(C), FFT(D), FFT(E);
		for (int i = 0; i < lim; ++i) reduce(R[i] = C[i] - mul(D[i], E[i]));
		FFT(R, 0);
	}
	void MOD(int *A, int *B, int m) {
		static int Q[N], R[N];
		DIV_MOD(A, B, Q, R, (m << 1) - 1, m + 1);
		std::memcpy(A, R, m << 2);
	}

	void POW(int *base, int p, int *Mod, int m) {
		static int res[N], T[N];
		res[0] = 1;
		while (p) {
			if (p & 1) {
				init(m << 1), std::memset(res + m, 0, lim - m << 2);
				std::memcpy(T, base, m << 2), std::memset(T + m, 0, lim - m << 2);
				FFT(T), FFT(res);
				for (int i = 0; i < lim; ++i) res[i] = mul(res[i], T[i]);
				FFT(res, 0); MOD(res, Mod, m);
			}
			p >>= 1;
			if (p) {
				init(m << 1), std::memset(base + m, 0, lim - m << 2);
				FFT(base);
				for (int i = 0; i < lim; ++i) base[i] = mul(base[i], base[i]);
				FFT(base, 0), MOD(base, Mod, m);
			}
		}
		std::memcpy(base, res, m << 2);
	}

	int solve(int *f, int *a, int n, int k) { //a为递推式0~k-1项,f为转移数组1~k项
		static int A[maxn], G[maxn];
		for (int i = 1; i <= k; ++i) reduce(G[k - i] = -f[i]);
		G[k] = A[1] = 1;
		Poly::POW(A, n, G, k);
		int ans = 0;
		for (int i = 0; i < k; ++i) reduce(ans += mul(A[i], a[i]) - mod);
		return ans;
	}
#undef N
}

int n, k;
int f[maxk], a[maxk];
int main() 
	std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout.tie(0);
	std::cin >> n >> k;
	for (int i = 1; i <= k; ++i) std::cin >> f[i];
	for (int i = 0; i < k; ++i) std::cin >> a[i], reduce(a[i]);
	std::cout << Poly::solve(f, a, n, k) << '\n';
	return 0;
}

 

发现取模用的多项式$G(x)$是固定的,可以预处理出它反转后的逆元,以及$G$和该逆元的点值表示等,从而减小常数。

C++ Code:(这一份常数还算正常)

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#define maxk 32010  // capacity of the input arrays f[1..k] and a[0..k)
#define maxn 65536  // capacity of every polynomial / NTT buffer (all length-lim transforms must fit)
const int mod = 998244353;

// Modular product of x and y, yielding a long long (in [0, mod) for
// non-negative operands). The whole expansion is parenthesized so the macro
// acts as a single operand inside larger expressions such as `2 * mul(a, b)`.
#define mul(x, y) (static_cast<long long> (x) * (y) % mod)

namespace Math {
	// Fast modular exponentiation: returns base^p mod `mod` for p >= 0.
	// The accumulator is an ordinary local, so pw is reentrant.
	inline int pw(int base, int p) {
		int res = 1;
		for (; p; p >>= 1, base = mul(base, base))
			if (p & 1) res = mul(res, base);
		return res;
	}
	// Modular inverse by Fermat's little theorem (relies on mod being prime);
	// x must be non-zero modulo mod.
	inline int inv(int x) { return pw(x, mod - 2); }
}
// Folds x from (-mod, mod) into [0, mod) by adding mod when x is negative.
// Uses a comparison instead of `x >> 31 & mod`, which depended on an
// arithmetic right shift of a negative int (implementation-defined before
// C++20) and on int being exactly 32 bits wide.
inline void reduce(int &x) { x += x < 0 ? mod : 0; }

namespace Poly {
#define N maxn
	int lim, s, rev[N], Wn[N];
	inline void init(const int n) {
		lim = 1, s = -1; while (lim < n) lim <<= 1, ++s;
		for (register int i = 1; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
		const int t = Math::pw(3, (mod - 1) / lim);
		*Wn = 1; for (register int *i = Wn + 1; i != Wn + lim; ++i) *i = mul(*(i - 1), t);
	}
	inline void FFT(int *A, const int op = 1) {
		for (register int i = 1; i < lim; ++i) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
		for (register int mid = 1; mid < lim; mid <<= 1) {
			const int t = lim / mid >> 1;
			for (register int i = 0; i < lim; i += mid << 1)
				for (register int j = 0; j < mid; ++j) {
					const int X = A[i + j], Y = mul(A[i + j + mid], Wn[t * j]);
					reduce(A[i + j] += Y - mod), reduce(A[i + j + mid] = X - Y);
				}
		}
		if (!op) {
			const int ilim = Math::inv(lim);
			for (register int *i = A; i != A + lim; ++i) *i = mul(*i, ilim);
			std::reverse(A + 1, A + lim);
		}
	}

	void INV(int *A, int *B, int n) {
		if (n == 1) { *B = Math::inv(*A); return ; }
		static int C[N], D[N];
		const int len = n + 1 >> 1;
		INV(A, B, len), init(len * 3);
		std::memcpy(C, A, n << 2), std::memset(C + n, 0, lim - n << 2);
		std::memcpy(D, B, len << 2), std::memset(D + len, 0, lim - len << 2);
		FFT(C), FFT(D);
		for (int i = 0; i < lim; ++i) D[i] = (2 - mul(D[i], C[i]) + mod) * D[i] % mod;
		FFT(D, 0);
		std::memcpy(B + len, D + len, n - len << 2);
	}

	int G[N], INVG[N];
	void DIV(int *A, int *Q, int n, int m) {
		static int C[N];
		const int len = n - m + 1;
		std::reverse_copy(A, A + n, C), std::memset(C + len, 0, lim - len << 2);
		FFT(C);
		for (int i = 0; i < lim; ++i) Q[i] = mul(C[i], INVG[i]);
		FFT(Q, 0), std::reverse(Q, Q + len);
	}
	void DIV_MOD(int *A, int *R, int n, int m) {
		static int Q[N];
		const int len = n - m + 1;
		DIV(A, Q, n, m), std::memset(Q + len, 0, lim - len << 2);
		FFT(Q);
		for (int i = 0; i < lim; ++i) R[i] = mul(G[i], Q[i]);
		FFT(R, 0);
		for (int i = 0; i < m; ++i) reduce(R[i] = A[i] - R[i]);
	}

	void POW(int *A, int p, int m) {
		if (!p) return ;
		POW(A, p >> 1, m);
		static int T[N];
		std::memcpy(T, A, m << 2), std::memset(T + m, 0, lim - m << 2);
		FFT(T);
		for (int i = 0; i < lim; ++i) T[i] = mul(T[i], T[i]);
		FFT(T, 0);
		if (p & 1) {
			for (int i = 2 * m - 1; ~i; --i) T[i] = T[i - 1];
			T[0] = 0;
		}
		DIV_MOD(T, A, 2 * m, m + 1);
	}

	int solve(int *f, int *a, int n, int k) { //a为递推式0~k-1项,f为转移数组1~k项
		static int A[maxn], B[maxn];
		for (int i = 1; i <= k; ++i) reduce(G[k - i] = -f[i]);
		G[k] = A[0] = 1;
		std::reverse_copy(G, G + k + 1, B), B[k] = 0;
		INV(B, INVG, k), init(k << 1);
		FFT(G), FFT(INVG);
		Poly::POW(A, n, k);
		int ans = 0;
		for (int i = 0; i < k; ++i) reduce(ans += mul(A[i], a[i]) - mod);
		return ans;
	}
#undef N
}

int n, k;
int f[maxk], a[maxk];

// Entry point: reads n and k, the coefficients f_1..f_k and the initial terms
// a_0..a_{k-1}, then prints the value returned by Poly::solve.
int main() {
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	std::cout.tie(nullptr);
	std::cin >> n >> k;
	for (int idx = 1; idx <= k; ++idx) std::cin >> f[idx];
	for (int idx = 0; idx < k; ++idx) {
		std::cin >> a[idx];
		reduce(a[idx]);  // negative input -> [0, mod) (assumes |a_i| < mod)
	}
	std::cout << Poly::solve(f, a, n, k) << '\n';
	return 0;
}

  

 

posted @ 2019-02-16 21:20  Memory_of_winter  阅读(186)  评论(0)  编辑  收藏  举报