对于所有长度为 \(n\) 且总和为 \(m\) 的任意正整数序列 \(a\),求 \(\sum\prod a_i \bmod~ 998244353\)

限制:

  • \(1 \leqslant n, m \leqslant 2 \times 10^5\)

算法分析

做法一:积的和典型

一方面,满足 \(\sum a_i = m\) 的正整数序列个数,可以考虑在 \(m\) 个白球中插入 \(n-1\) 个隔板,那么就会得到 \(n\) 个区间

另一方面,\(\prod a_i\) 可以理解为在每个区间中任选一个球(不妨将这个球染成红色)的方案数,因为不同区间的选法是独立的,所以是乘法

然后将这二者综合起来考虑,不妨把隔板换成红球(不会影响上面的选法数),这样就得到 \(n+m-1\) 个球,且其中有 \(2n-1\) 个红球,那么总方案数就是 \(\binom{n+m-1}{2n-1}\)

(更喜欢这种做法,虽然理解起来会比较困难)

做法二:生成函数

写成生成函数的式子就是 \([x^m](x + 2x^2 + 3x^3 + \cdots)^n\)

\(f(x) = x + 2x^2 + 3x^3 + \cdots\)
\( \begin{aligned} (1-x)f(x) &= (x + 2x^2 + 3x^3 + \cdots) - (x^2 + 2x^3 + 3x^4 + \cdots)\\ &= x + x^2 + x^3 + \cdots\\ &= \frac{x}{1-x} \end{aligned} \)

\( \Rightarrow ~ f(x) = \frac{x}{(1-x)^2} \)

于是,\([x^m]f(x)^n = [x^m]\frac{x^n}{(1-x)^{2n}} = [x^{m-n}]\frac{1}{(1-x)^{2n}}\)

这个值可以用负数的二项式写成 \((-1)^{m-n}\binom{-2n}{m-n}\)

一般地,我们有 \(\binom{-n}{k} = (-1)^k\binom{n+k-1}{k}\)

所以,\((-1)^{m-n}\binom{-2n}{m-n} = \binom{m+n-1}{m-n} = \binom{m+n-1}{2n-1}\)

代码实现
#include <bits/stdc++.h>

using namespace std;
using ll = long long;

const int mod = 998244353;
//const int mod = 1000000007;
struct mint {
    ll x;
    mint(ll x=0):x((x%mod+mod)%mod) {}
    mint operator-() const {
        return mint(-x);
    }
    mint& operator+=(const mint a) {
        if ((x += a.x) >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint a) {
        if ((x += mod-a.x) >= mod) x -= mod;
        return *this;
    }
    mint& operator*=(const mint a) {
        (x *= a.x) %= mod;
        return *this;
    }
    mint operator+(const mint a) const {
        return mint(*this) += a;
    }
    mint operator-(const mint a) const {
        return mint(*this) -= a;
    }
    mint operator*(const mint a) const {
        return mint(*this) *= a;
    }
    mint pow(ll t) const {
        if (!t) return 1;
        mint a = pow(t>>1);
        a *= a;
        if (t&1) a *= *this;
        return a;
    }

    // for prime mod
    mint inv() const {
        return pow(mod-2);
    }
    mint& operator/=(const mint a) {
        return *this *= a.inv();
    }
    mint operator/(const mint a) const {
        return mint(*this) /= a;
    }
};
istream& operator>>(istream& is, mint& a) {
    return is >> a.x;
}
ostream& operator<<(ostream& os, const mint& a) {
    return os << a.x;
}

struct modinv {
  int n; vector<mint> d;
  modinv(): n(2), d({0,1}) {}
  mint operator()(int i) {
    while (n <= i) d.push_back(-d[mod%n]*(mod/n)), ++n;
    return d[i];
  }
  mint operator[](int i) const { return d[i];}
} invs;
struct modfact {
  int n; vector<mint> d;
  modfact(): n(2), d({1,1}) {}
  mint operator()(int i) {
    while (n <= i) d.push_back(d.back()*n), ++n;
    return d[i];
  }
  mint operator[](int i) const { return d[i];}
} facts;
struct modfactinv {
  int n; vector<mint> d;
  modfactinv(): n(2), d({1,1}) {}
  mint operator()(int i) {
    while (n <= i) d.push_back(d.back()*invs(n)), ++n;
    return d[i];
  }
  mint operator[](int i) const { return d[i];}
} ifacts;
mint comb(int n, int k) {
  if (n < k || k < 0) return 0;
  return facts(n)*ifacts(k)*ifacts(n-k);
}

int main() {
    int n, m;
    cin >> n >> m;
    
    mint ans = comb(m+n-1, 2*n-1);
    cout << ans << '\n';
    
    return 0;
}