题目

满足以下条件的长为 \(n\) 的整数序列 \(a_1, a_2, \cdots, a_n\) 有多少个?

  • \(1 \leqslant a_i \leqslant m\)\(1 \leqslant i \leqslant n\)
  • \(|a_i-a_{i+1}| \geqslant k\ (1 \leqslant i \leqslant n-1)\)

输出答案除以 \(998244353\) 的余数。

限制:

  • \(2 \leqslant n \leqslant 1000\)
  • \(1 \leqslant m \leqslant 5000\)
  • \(0 \leqslant k \leqslant m-1\)

算法分析

dp[i][j] 表示长为 \(i\) 的结尾等于 \(j\) 的数列个数

转移方程:

\( dp[i][j] = (dp[i-1][1] + \cdots + dp[i-1][j-k]) + (dp[i-1][j+k] + \cdots + dp[i-1][m]) \)

然后用前缀和去优化
s[i][j] 表示长为 \(i\) 且结尾 \(\leqslant j\) 的数列个数

初始值:

\(dp[1][1 \sim m] = 1\)
\(s[1][j] = j\)

还需特判一下 \(k = 0\) 的情况,答案是 \(m^n\)

代码实现
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 1; i <= (n); ++i)

using namespace std;
using ll = long long;

const int mod = 998244353;
//const int mod = 1000000007;
struct mint {
    ll x;
    mint(ll x=0):x((x%mod+mod)%mod) {}
    mint operator-() const {
        return mint(-x);
    }
    mint& operator+=(const mint a) {
        if ((x += a.x) >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint a) {
        if ((x += mod-a.x) >= mod) x -= mod;
        return *this;
    }
    mint& operator*=(const mint a) {
        (x *= a.x) %= mod;
        return *this;
    }
    mint operator+(const mint a) const {
        return mint(*this) += a;
    }
    mint operator-(const mint a) const {
        return mint(*this) -= a;
    }
    mint operator*(const mint a) const {
        return mint(*this) *= a;
    }
    mint pow(ll t) const {
        if (!t) return 1;
        mint a = pow(t>>1);
        a *= a;
        if (t&1) a *= *this;
        return a;
    }

    // for prime mod
    mint inv() const {
        return pow(mod-2);
    }
    mint& operator/=(const mint a) {
        return *this *= a.inv();
    }
    mint operator/(const mint a) const {
        return mint(*this) /= a;
    }
};
istream& operator>>(istream& is, mint& a) {
    return is >> a.x;
}
ostream& operator<<(ostream& os, const mint& a) {
    return os << a.x;
}

mint dp[1005][5005];
mint s[1005][5005];

int main() {
    int n, m, k;
    cin >> n >> m >> k;
    
    if (k == 0) {
        cout << mint(m).pow(n) << '\n';
        return 0;
    }
    
    rep(j, m) dp[1][j] = 1;
    rep(j, m) s[1][j] = j;
    for (int i = 2; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            mint now;
            if (j-k >= 0) now += s[i-1][j-k];
            if (j+k-1 <= m) now += s[i-1][m] - s[i-1][j+k-1];
            dp[i][j] = now;
            s[i][j] = s[i][j-1] + now;
        }
    }
    
    mint ans;
    rep(j, m) {
        ans += dp[n][j];
    }
    
    cout << ans << '\n';
    
    return 0;
}