对于所有长度为 n 且总和为 m 的任意正整数序列 a,求 aimod 998244353

限制:

  • 1n,m2×105

算法分析

做法一:积的和典型

一方面,满足 ai=m 的正整数序列个数,可以考虑在 m 个白球中插入 n1 个隔板,那么就会得到 n 个区间

另一方面,ai 可以理解为在每个区间中任选一个球(不妨将这个球染成红色)的方案数,因为不同区间的选法是独立的,所以是乘法

然后将这二者综合起来考虑,不妨把隔板换成红球(不会影响上面的选法数),这样就得到 n+m1 个球,且其中有 2n1 个红球,那么总方案数就是 (n+m12n1)

(更喜欢这种做法,虽然理解起来会比较困难)

做法二:生成函数

写成生成函数的式子就是 [xm](x+2x2+3x3+)n

f(x)=x+2x2+3x3+
(1x)f(x)=(x+2x2+3x3+)(x2+2x3+3x4+)=x+x2+x3+=x1x

 f(x)=x(1x)2

于是,[xm]f(x)n=[xm]xn(1x)2n=[xmn]1(1x)2n

这个值可以用负数的二项式写成 (1)mn(2nmn)

一般地,我们有 (nk)=(1)k(n+k1k)

所以,(1)mn(2nmn)=(m+n1mn)=(m+n12n1)

代码实现
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int mod = 998244353;
//const int mod = 1000000007;
struct mint {
ll x;
mint(ll x=0):x((x%mod+mod)%mod) {}
mint operator-() const {
return mint(-x);
}
mint& operator+=(const mint a) {
if ((x += a.x) >= mod) x -= mod;
return *this;
}
mint& operator-=(const mint a) {
if ((x += mod-a.x) >= mod) x -= mod;
return *this;
}
mint& operator*=(const mint a) {
(x *= a.x) %= mod;
return *this;
}
mint operator+(const mint a) const {
return mint(*this) += a;
}
mint operator-(const mint a) const {
return mint(*this) -= a;
}
mint operator*(const mint a) const {
return mint(*this) *= a;
}
mint pow(ll t) const {
if (!t) return 1;
mint a = pow(t>>1);
a *= a;
if (t&1) a *= *this;
return a;
}
// for prime mod
mint inv() const {
return pow(mod-2);
}
mint& operator/=(const mint a) {
return *this *= a.inv();
}
mint operator/(const mint a) const {
return mint(*this) /= a;
}
};
istream& operator>>(istream& is, mint& a) {
return is >> a.x;
}
ostream& operator<<(ostream& os, const mint& a) {
return os << a.x;
}
struct modinv {
int n; vector<mint> d;
modinv(): n(2), d({0,1}) {}
mint operator()(int i) {
while (n <= i) d.push_back(-d[mod%n]*(mod/n)), ++n;
return d[i];
}
mint operator[](int i) const { return d[i];}
} invs;
struct modfact {
int n; vector<mint> d;
modfact(): n(2), d({1,1}) {}
mint operator()(int i) {
while (n <= i) d.push_back(d.back()*n), ++n;
return d[i];
}
mint operator[](int i) const { return d[i];}
} facts;
struct modfactinv {
int n; vector<mint> d;
modfactinv(): n(2), d({1,1}) {}
mint operator()(int i) {
while (n <= i) d.push_back(d.back()*invs(n)), ++n;
return d[i];
}
mint operator[](int i) const { return d[i];}
} ifacts;
mint comb(int n, int k) {
if (n < k || k < 0) return 0;
return facts(n)*ifacts(k)*ifacts(n-k);
}
int main() {
int n, m;
cin >> n >> m;
mint ans = comb(m+n-1, 2*n-1);
cout << ans << '\n';
return 0;
}