square869120Contest #3 G Sum of Fibonacci Sequence

洛谷传送门

AtCoder 传送门

特判 \(n = 1\)。将 \(n, m\) 都减 \(1\),答案即为

\[[x^m]\frac{1}{(1 - x - x^2)(1 - x)^n} \]

若能把这个分式拆成 \(\frac{A(x)}{(1 - x)^n} + \frac{B(x)}{1 - x - x^2}\) 的形式,其中 \(\deg A(x) \le n - 1, \deg B(x) \le 1\),那么答案就是好算的。

先考虑怎么求出一组合法的 \(A(x), B(x)\),满足 \(A(x)(1 - x - x^2) + B(x)(1 - x)^n = 1\)。因为 \(\deg B(x) \le 1\) 所以它比较好求,所以先求 \(B(x)\)。前面那个式子可以看成是对所有 \(x\) 都成立,那么我们代入 \(1 - x - x^2\) 的两个根 \(x_1 = \frac{-1 - \sqrt 5}{2}\)\(x_2 = \frac{-1 + \sqrt 5}{2}\),得到:

\[\begin{cases} B(x_1)(1 - x_1)^n = 1 \\ B(x_2)(1 - x_2)^n = 1 \end{cases} \]

因为 \(\deg B(x) \le 1\) 所以这样可以直接解出 \(B(x)\)

注意我们现在讨论的都是实数,实现时可以把每个数都用 \(a + b \sqrt 5\) 表示,封装一个结构体即可。

解出 \(B(x)\) 后可以解 \(A(x)\)

\[A(x) = \frac{1 - B(x)(1 - x)^n}{1 - x - x^2} \]

因为能除尽,所以可以直接暴力大除法。

那么此时答案即为:

\[[x^m] \frac{A(x)}{(1 - x)^n} + [x^m] \frac{B(x)}{1 - x - x^2} \]

先看左半部分:

\[[x^m] \frac{A(x)}{(1 - x)^n} = \sum\limits_{i \ge 0} [x^i] A(x) \times [x^{m - i}] \frac{1}{(1 - x)^n} = \sum\limits_{i \ge 0} [x^i] A(x) \times \binom{n + m - i - 1}{n - 1} \]

组合数可以 \(O(n)\) 预处理前缀积和后缀积后 \(O(1)\) 计算。

再看右半部分(\(f_m\) 为斐波那契数列的第 \(m\) 项):

\[[x^m] \frac{B(x)}{1 - x - x^2} = [x^m] \frac{ax + b}{1 - x - x^2} = af_m + bf_{m + 1} \]

\(f_n\) 可以直接套通项公式计算:

\[f_n = \frac{\sqrt 5}{5} (\frac{1 + \sqrt 5}{2})^n - \frac{\sqrt 5}{5} (\frac{1 - \sqrt 5}{2})^n \]

那么这题就做完了。时间复杂度 \(O(n + \log m)\)

code
// Problem: G - Sum of Fibonacci Sequence
// Contest: AtCoder - square869120Contest #3
// URL: https://atcoder.jp/contests/s8pc-3/tasks/s8pc_3_g
// Memory Limit: 256 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 200100;
const ll mod = 998244353;
const ll inv2 = (mod + 1) / 2;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

const ll inv5 = qpow(5, mod - 2);

ll n, m, fac[maxn], ifac[maxn], pre[maxn], suf[maxn];

inline ll C(ll n, ll m) {
	if (n < m || n < 0 || m < 0) {
		return 0;
	} else {
		return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
	}
}

struct node {
	ll x, y;
	node(ll a = 0, ll b = 0) : x(a), y(b) {}
} a[9][9];

inline node operator + (const node &a, const node &b) {
	return node((a.x + b.x) % mod, (a.y + b.y) % mod);
}

inline node operator - (const node &a, const node &b) {
	return node((a.x - b.x + mod) % mod, (a.y - b.y + mod) % mod);
}

inline node operator * (const node &a, const node &b) {
	return node((a.x * b.x + a.y * b.y % mod * 5) % mod, (a.x * b.y + a.y * b.x) % mod);
}

inline node operator / (const node &a, const node &b) {
	ll inv = qpow((b.x * b.x - b.y * b.y % mod * 5 % mod + mod) % mod, mod - 2);
	return node((a.x * b.x - a.y * b.y % mod * 5 % mod + mod) % mod * inv % mod, (a.y * b.x - a.x * b.y % mod + mod) % mod * inv % mod);
}

inline node qpow(node a, ll p) {
	node res(1, 0);
	while (p) {
		if (p & 1) {
			res = res * a;
		}
		a = a * a;
		p >>= 1;
	}
	return res;
}

typedef vector<node> poly;

inline poly operator * (poly a, poly b) {
	int n = (int)a.size() - 1, m = (int)b.size() - 1;
	poly res(n + m + 1);
	for (int i = 0; i <= n; ++i) {
		for (int j = 0; j <= m; ++j) {
			res[i + j] = res[i + j] + a[i] * b[j];
		}
	}
	return res;
}

inline poly operator / (poly a, poly b) {
	int n = (int)a.size() - 1, m = (int)b.size() - 1;
	poly res(n - m + 1);
	node I = 1 / b[m];
	for (int i = n - m; ~i; --i) {
		res[i] = a[i + m] * I;
		for (int j = 0; j <= m; ++j) {
			a[i + j] = a[i + j] - res[i] * b[j];
		}
	}
	return res;
}

inline ll calc(ll n) {
	node a(0, inv5), x(inv2, inv2), b(0, (mod - inv5) % mod), y(inv2, (mod - inv2) % mod);
	node res = a * qpow(x, n) + b * qpow(y, n);
	return res.x;
}

void solve() {
	scanf("%lld%lld", &n, &m);
	fac[0] = 1;
	for (int i = 1; i <= n; ++i) {
		fac[i] = fac[i - 1] * i % mod;
	}
	ifac[n] = qpow(fac[n], mod - 2);
	for (int i = n - 1; ~i; --i) {
		ifac[i] = ifac[i + 1] * (i + 1) % mod;
	}
	if (n == 1) {
		printf("%lld\n", calc(m));
		return;
	}
	if (m == 1) {
		puts("1");
		return;
	}
	--m;
	--n;
	node x1(mod - inv2, mod - inv2), x2(mod - inv2, inv2);
	node p = 1 / qpow(1 - x1, n), q = 1 / qpow(1 - x2, n);
	node a = (p - q) / (x1 - x2);
	node b = p - x1 * a;
	poly B(2);
	B[0] = b;
	B[1] = a;
	poly A(n + 1), F(3);
	F[0] = 1;
	F[1] = F[2] = mod - 1;
	for (int i = 0; i <= n; ++i) {
		A[i] = (i & 1) ? (mod - C(n, i)) % mod : C(n, i);
	}
	A = A * B;
	for (node &x : A) {
		x = 0 - x;
	}
	A[0] = A[0] + 1;
	A = A / F;
	node ans(0, 0);
	pre[0] = (m + 1) % mod;
	for (int i = 1; i <= n + 5; ++i) {
		pre[i] = pre[i - 1] * ((m + i + 1) % mod) % mod;
	}
	suf[0] = m % mod;
	for (int i = 1; i <= n + 5; ++i) {
		suf[i] = suf[i - 1] * ((m - i + mod) % mod) % mod;
	}
	for (int i = 0; i <= min(n - 1, m); ++i) {
		ll x = m - i, res = ifac[n - 1];
		if (n + x - 1 - (m + 1) >= 0) {
			res = res * pre[n + x - 1 - (m + 1)] % mod;
		}
		if (m - (x + 1) >= 0) {
			res = res * suf[m - (x + 1)] % mod;
		}
		ans = ans + res * A[i];
	}
	ans = ans + a * calc(m) + b * calc(m + 1);
	printf("%lld\n", ans.x);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}

posted @ 2024-06-15 14:24  zltzlt  阅读(13)  评论(0编辑  收藏  举报