square869120Contest #3 G Sum of Fibonacci Sequence
特判 \(n = 1\)(此时答案就是 \(f_m\))。否则将 \(n, m\) 都减 \(1\),答案即为
\[[x^m]\frac{1}{(1 - x - x^2)(1 - x)^n}
\]
若能把这个分式拆成 \(\frac{A(x)}{(1 - x)^n} + \frac{B(x)}{1 - x - x^2}\) 的形式,其中 \(\deg A(x) \le n - 1, \deg B(x) \le 1\),那么答案就是好算的。
先考虑怎么求出一组合法的 \(A(x), B(x)\),满足 \(A(x)(1 - x - x^2) + B(x)(1 - x)^n = 1\)(由于 \(1 - x - x^2\) 与 \((1 - x)^n\) 互素,这样的 \(A(x), B(x)\) 一定存在)。\(B(x)\) 只有两个未知系数,比较好求,所以先求 \(B(x)\)。上式是多项式恒等式,对所有 \(x\) 都成立;代入 \(1 - x - x^2\) 的两个根 \(x_1 = \frac{-1 - \sqrt 5}{2}\) 和 \(x_2 = \frac{-1 + \sqrt 5}{2}\) 后,含 \(A(x)\) 的项变为 \(0\),得到:
\[\begin{cases}
B(x_1)(1 - x_1)^n = 1 \\
B(x_2)(1 - x_2)^n = 1
\end{cases}
\]
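因为 \(\deg B(x) \le 1\),设 \(B(x) = ax + b\),由这两个方程可以直接解出 \(a = \frac{B(x_1) - B(x_2)}{x_1 - x_2}\),\(b = B(x_1) - a x_1\),其中 \(B(x_k) = (1 - x_k)^{-n}\)。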
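注意 \(x_1, x_2\) 都是无理数。实现时可以把每个数都表示成 \(u + v \sqrt 5\)(\(u, v\) 取模意义下的值),封装一个结构体即可。代码中的模数 \(998244353 \equiv 3 \pmod 5\),由二次互反律,\(5\) 是它的二次非剩余,因此所有形如 \(u + v \sqrt 5\) 的数构成一个域,除法可以通过乘共轭实现:\(\frac{1}{u + v \sqrt 5} = \frac{u - v \sqrt 5}{u^2 - 5 v^2}\)。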
解出 \(B(x)\) 后可以解 \(A(x)\):
\[A(x) = \frac{1 - B(x)(1 - x)^n}{1 - x - x^2}
\]
因为 \(1 - B(x)(1 - x)^n\) 在 \(x_1, x_2\) 处都为 \(0\),它能被 \(1 - x - x^2\) 整除,所以直接暴力做多项式长除法即可,复杂度 \(O(n)\)。
那么此时答案即为:
\[[x^m] \frac{A(x)}{(1 - x)^n} + [x^m] \frac{B(x)}{1 - x - x^2}
\]
先看左半部分:
\[[x^m] \frac{A(x)}{(1 - x)^n} = \sum\limits_{i \ge 0} [x^i] A(x) \times [x^{m - i}] \frac{1}{(1 - x)^n} = \sum\limits_{i \ge 0} [x^i] A(x) \times \binom{n + m - i - 1}{n - 1}
\]
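由于 \(m\) 可能很大,不能直接用阶乘计算。设 \(x = m - i\)(\(0 \le i \le \min(n - 1, m)\)),则 \(\binom{x + n - 1}{n - 1} = \frac{(x + 1)(x + 2) \cdots (x + n - 1)}{(n - 1)!}\)。因为 \(x + 1 \le m + 1\) 且 \(x + n - 1 \ge m\),分子可以拆成 \(m(m - 1) \cdots (x + 1)\) 和 \((m + 1)(m + 2) \cdots (x + n - 1)\) 两段。\(O(n)\) 预处理这两种连乘积(分别从 \(m\) 向下、从 \(m + 1\) 向上)后,每个组合数都能 \(O(1)\) 算出。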
再看右半部分(\(f_m\) 为斐波那契数列的第 \(m\) 项,\(f_0 = 0, f_1 = f_2 = 1\),于是 \(\frac{1}{1 - x - x^2} = \sum_{k \ge 0} f_{k + 1} x^k\)):
\[[x^m] \frac{B(x)}{1 - x - x^2} = [x^m] \frac{ax + b}{1 - x - x^2} = af_m + bf_{m + 1}
\]
\(f_n\) 可以直接套通项公式计算:
\[f_n = \frac{\sqrt 5}{5} (\frac{1 + \sqrt 5}{2})^n - \frac{\sqrt 5}{5} (\frac{1 - \sqrt 5}{2})^n
\]
那么这题就做完了。时间复杂度 \(O(n + \log m)\)。
code
// Problem: G - Sum of Fibonacci Sequence
// Contest: AtCoder - square869120Contest #3
// URL: https://atcoder.jp/contests/s8pc-3/tasks/s8pc_3_g
// Memory Limit: 256 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
// Short aliases for frequently used arithmetic / pair types.
using ll = long long;
using db = double;
using ull = unsigned long long;
using ldb = long double;
using pii = pair<ll, ll>;
// Capacity of the factorial and running-product tables.
constexpr int maxn = 200100;
// Prime modulus. 5 is a quadratic non-residue modulo it, so numbers of the
// form u + v*sqrt(5) (see struct node) form a field.
constexpr ll mod = 998244353;
// Modular inverse of 2.
constexpr ll inv2 = (mod + 1) / 2;
// Computes b^p modulo `mod` by binary exponentiation: scan the bits of p from
// the lowest upward, multiplying the accumulator in whenever a bit is set and
// squaring the base after every step. Expects 0 <= b < mod and p >= 0.
inline ll qpow(ll b, ll p) {
    ll acc = 1;
    for (; p != 0; p >>= 1, b = b * b % mod) {
        if (p & 1) {
            acc = acc * b % mod;
        }
    }
    return acc;
}
// Modular inverse of 5 via Fermat's little theorem (5^(mod-2)).
const ll inv5 = qpow(5, mod - 2);
// n, m: the input values (solve() decrements both in place after its special cases).
// fac / ifac: factorials and inverse factorials, filled by solve() up to the input n.
// pre / suf: running products (m + 1)(m + 2)... and m(m - 1)..., filled by solve()
// after the decrement; used to evaluate binomials with a huge upper argument.
ll n, m, fac[maxn], ifac[maxn], pre[maxn], suf[maxn];
// Binomial coefficient "n choose m" modulo `mod`, read from the fac / ifac
// tables. Returns 0 unless 0 <= m <= n (which also implies n >= 0); n must not
// exceed the range the tables were filled for.
inline ll C(ll n, ll m) {
    const bool inRange = m >= 0 && n >= m;
    return inRange ? fac[n] * ifac[m] % mod * ifac[n - m] % mod : 0;
}
// An element u + v*sqrt(5) with u, v taken modulo `mod`; stored as x = u, y = v
// (the operators below expect both in [0, mod)).
// The constructor is intentionally NOT explicit: the file relies on implicitly
// converting an integer k into k + 0*sqrt(5), e.g. `1 / b[m]`, `0 - x`,
// `res * A[i]` and `a * calc(m)`.
// (The former unused global array `a[9][9]` declared here was removed; every use
// of the name `a` in the file refers to a local or a parameter.)
struct node {
    ll x, y;
    node(ll a = 0, ll b = 0) : x(a), y(b) {}
};
// Sum of two u + v*sqrt(5) numbers: the rational and sqrt(5) parts add independently.
inline node operator + (const node &lhs, const node &rhs) {
    const ll rational = (lhs.x + rhs.x) % mod;
    const ll radical = (lhs.y + rhs.y) % mod;
    return node(rational, radical);
}
// Difference of two u + v*sqrt(5) numbers, computed component-wise; adding
// `mod` before reducing keeps each component non-negative for inputs in [0, mod).
inline node operator - (const node &lhs, const node &rhs) {
    const ll rational = (lhs.x - rhs.x + mod) % mod;
    const ll radical = (lhs.y - rhs.y + mod) % mod;
    return node(rational, radical);
}
// Product (p + q*s)(u + v*s) with s = sqrt(5), s^2 = 5:
//   rational part p*u + 5*q*v, sqrt(5) part p*v + q*u.
inline node operator * (const node &lhs, const node &rhs) {
    const ll rational = lhs.x * rhs.x + lhs.y * rhs.y % mod * 5;
    const ll radical = lhs.x * rhs.y + lhs.y * rhs.x;
    return node(rational % mod, radical % mod);
}
// Quotient (p + q*s) / (u + v*s) with s = sqrt(5), computed by multiplying
// with the conjugate: (p + q*s)(u - v*s) / (u^2 - 5*v^2). The norm's inverse
// comes from Fermat's little theorem.
inline node operator / (const node &lhs, const node &rhs) {
    const ll norm = (rhs.x * rhs.x - rhs.y * rhs.y % mod * 5 % mod + mod) % mod;
    const ll inv = qpow(norm, mod - 2);
    const ll rational = (lhs.x * rhs.x - lhs.y * rhs.y % mod * 5 % mod + mod) % mod;
    const ll radical = (lhs.y * rhs.x - lhs.x * rhs.y % mod + mod) % mod;
    return node(rational * inv % mod, radical * inv % mod);
}
// Raises a u + v*sqrt(5) number to the power p (p >= 0) by binary
// exponentiation, starting from the identity 1 + 0*sqrt(5).
inline node qpow(node a, ll p) {
    node acc(1, 0);
    for (node base = a; p; p >>= 1) {
        if (p & 1) {
            acc = acc * base;
        }
        base = base * base;
    }
    return acc;
}
// Polynomial with coefficients in Z_mod[sqrt 5]; index i holds the x^i coefficient.
typedef vector<node> poly;
// Schoolbook product of two polynomials, O(|a| * |b|).
// Operands are taken by const reference, so neither vector is copied.
// An empty operand denotes the zero polynomial and yields an empty result.
// Previously two empty operands built poly(-1), i.e. a huge size_t length, which throws.
inline poly operator * (const poly &a, const poly &b) {
    if (a.empty() || b.empty()) {
        return poly();
    }
    const int n = (int)a.size() - 1, m = (int)b.size() - 1;
    poly res(n + m + 1);
    for (int i = 0; i <= n; ++i) {
        for (int j = 0; j <= m; ++j) {
            res[i + j] = res[i + j] + a[i] * b[j];
        }
    }
    return res;
}
// Long division of `a` by `b`; returns the quotient, of degree deg(a) - deg(b).
// The remainder is discarded, so the result equals a / b only when the division
// is exact. `b`'s leading coefficient must be invertible.
// `a` is taken by value because it is used as the scratch remainder.
// The following cases return an empty (zero) polynomial:
//   - deg(a) < deg(b): the quotient is zero. Previously this could construct
//     poly(n - m + 1) with a negative size, which throws.
//   - b is empty: division by the zero polynomial. Previously this read b[-1].
inline poly operator / (poly a, const poly &b) {
    const int n = (int)a.size() - 1, m = (int)b.size() - 1;
    if (m < 0 || n < m) {
        return poly();
    }
    poly res(n - m + 1);
    const node I = 1 / b[m];  // inverse of the divisor's leading coefficient
    for (int i = n - m; i >= 0; --i) {
        res[i] = a[i + m] * I;
        for (int j = 0; j <= m; ++j) {
            a[i + j] = a[i + j] - res[i] * b[j];
        }
    }
    return res;
}
// Fibonacci number f_n (f_0 = 0, f_1 = 1) modulo `mod` via Binet's formula
// evaluated in Z_mod[sqrt 5]:
//   f_n = (sqrt5/5) * phi^n - (sqrt5/5) * psi^n,
//   phi = (1 + sqrt5)/2, psi = (1 - sqrt5)/2.
// The sqrt(5) component of the result cancels, so only the rational part is returned.
inline ll calc(ll n) {
    const node phi(inv2, inv2);
    const node psi(inv2, (mod - inv2) % mod);
    const node coef(0, inv5);                   //  sqrt5 / 5
    const node negCoef(0, (mod - inv5) % mod);  // -sqrt5 / 5
    const node value = coef * qpow(phi, n) + negCoef * qpow(psi, n);
    return value.x;
}
void solve() {
scanf("%lld%lld", &n, &m);
fac[0] = 1;
for (int i = 1; i <= n; ++i) {
fac[i] = fac[i - 1] * i % mod;
}
ifac[n] = qpow(fac[n], mod - 2);
for (int i = n - 1; ~i; --i) {
ifac[i] = ifac[i + 1] * (i + 1) % mod;
}
if (n == 1) {
printf("%lld\n", calc(m));
return;
}
if (m == 1) {
puts("1");
return;
}
--m;
--n;
node x1(mod - inv2, mod - inv2), x2(mod - inv2, inv2);
node p = 1 / qpow(1 - x1, n), q = 1 / qpow(1 - x2, n);
node a = (p - q) / (x1 - x2);
node b = p - x1 * a;
poly B(2);
B[0] = b;
B[1] = a;
poly A(n + 1), F(3);
F[0] = 1;
F[1] = F[2] = mod - 1;
for (int i = 0; i <= n; ++i) {
A[i] = (i & 1) ? (mod - C(n, i)) % mod : C(n, i);
}
A = A * B;
for (node &x : A) {
x = 0 - x;
}
A[0] = A[0] + 1;
A = A / F;
node ans(0, 0);
pre[0] = (m + 1) % mod;
for (int i = 1; i <= n + 5; ++i) {
pre[i] = pre[i - 1] * ((m + i + 1) % mod) % mod;
}
suf[0] = m % mod;
for (int i = 1; i <= n + 5; ++i) {
suf[i] = suf[i - 1] * ((m - i + mod) % mod) % mod;
}
for (int i = 0; i <= min(n - 1, m); ++i) {
ll x = m - i, res = ifac[n - 1];
if (n + x - 1 - (m + 1) >= 0) {
res = res * pre[n + x - 1 - (m + 1)] % mod;
}
if (m - (x + 1) >= 0) {
res = res * suf[m - (x + 1)] % mod;
}
ans = ans + res * A[i];
}
ans = ans + a * calc(m) + b * calc(m + 1);
printf("%lld\n", ans.x);
}
// Entry point: runs exactly one test case. To support multiple cases, read T
// from input instead of fixing it to 1.
int main() {
    int T = 1;
    // scanf("%d", &T);
    for (; T > 0; --T) {
        solve();
    }
    return 0;
}