对于所有长度为 \(n\) 且总和为 \(m\) 的任意正整数序列 \(a\),求 \(\sum\prod a_i \bmod~ 998244353\) 。
限制:
- \(1 \leqslant n, m \leqslant 2 \times 10^5\)
算法分析
做法一:积的和典型
一方面,满足 \(\sum a_i = m\) 的正整数序列个数,可以考虑在 \(m\) 个白球中插入 \(n-1\) 个隔板,那么就会得到 \(n\) 个区间
另一方面,\(\prod a_i\) 可以理解为在每个区间中任选一个球(不妨将这个球染成红色)的方案数,因为不同区间的选法是独立的,所以是乘法
然后将这二者综合起来考虑,不妨把隔板换成红球(不会影响上面的选法数),这样就得到 \(n+m-1\) 个球,且其中有 \(2n-1\) 个红球,那么总方案数就是 \(\binom{n+m-1}{2n-1}\)
(更喜欢这种做法,虽然理解起来会比较困难)
做法二:生成函数
写成生成函数的式子就是 \([x^m](x + 2x^2 + 3x^3 + \cdots)^n\)
记 \(f(x) = x + 2x^2 + 3x^3 + \cdots\)
\(
\begin{aligned}
(1-x)f(x) &= (x + 2x^2 + 3x^3 + \cdots) - (x^2 + 2x^3 + 3x^4 + \cdots)\\
&= x + x^2 + x^3 + \cdots\\
&= \frac{x}{1-x}
\end{aligned}
\)
\( \Rightarrow ~ f(x) = \frac{x}{(1-x)^2} \)
于是,\([x^m]f(x)^n = [x^m]\frac{x^n}{(1-x)^{2n}} = [x^{m-n}]\frac{1}{(1-x)^{2n}}\)
这个值可以用负数的二项式写成 \((-1)^{m-n}\binom{-2n}{m-n}\)
一般地,我们有 \(\binom{-n}{k} = (-1)^k\binom{n+k-1}{k}\)
所以,\((-1)^{m-n}\binom{-2n}{m-n} = \binom{m+n-1}{m-n} = \binom{m+n-1}{2n-1}\) 。
代码实现
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int mod = 998244353;
//const int mod = 1000000007;
struct mint {
ll x;
mint(ll x=0):x((x%mod+mod)%mod) {}
mint operator-() const {
return mint(-x);
}
mint& operator+=(const mint a) {
if ((x += a.x) >= mod) x -= mod;
return *this;
}
mint& operator-=(const mint a) {
if ((x += mod-a.x) >= mod) x -= mod;
return *this;
}
mint& operator*=(const mint a) {
(x *= a.x) %= mod;
return *this;
}
mint operator+(const mint a) const {
return mint(*this) += a;
}
mint operator-(const mint a) const {
return mint(*this) -= a;
}
mint operator*(const mint a) const {
return mint(*this) *= a;
}
mint pow(ll t) const {
if (!t) return 1;
mint a = pow(t>>1);
a *= a;
if (t&1) a *= *this;
return a;
}
// for prime mod
mint inv() const {
return pow(mod-2);
}
mint& operator/=(const mint a) {
return *this *= a.inv();
}
mint operator/(const mint a) const {
return mint(*this) /= a;
}
};
istream& operator>>(istream& is, mint& a) {
return is >> a.x;
}
ostream& operator<<(ostream& os, const mint& a) {
return os << a.x;
}
struct modinv {
int n; vector<mint> d;
modinv(): n(2), d({0,1}) {}
mint operator()(int i) {
while (n <= i) d.push_back(-d[mod%n]*(mod/n)), ++n;
return d[i];
}
mint operator[](int i) const { return d[i];}
} invs;
struct modfact {
int n; vector<mint> d;
modfact(): n(2), d({1,1}) {}
mint operator()(int i) {
while (n <= i) d.push_back(d.back()*n), ++n;
return d[i];
}
mint operator[](int i) const { return d[i];}
} facts;
struct modfactinv {
int n; vector<mint> d;
modfactinv(): n(2), d({1,1}) {}
mint operator()(int i) {
while (n <= i) d.push_back(d.back()*invs(n)), ++n;
return d[i];
}
mint operator[](int i) const { return d[i];}
} ifacts;
mint comb(int n, int k) {
if (n < k || k < 0) return 0;
return facts(n)*ifacts(k)*ifacts(n-k);
}
int main() {
int n, m;
cin >> n >> m;
mint ans = comb(m+n-1, 2*n-1);
cout << ans << '\n';
return 0;
}