[题解] 好好

题意

不定方程解的计数问题

\[\sum_{x_1+x_2+\cdots + x_m = N, x_i \in \N} \prod x_i ^ {K_i} \]

  • \(\sum K_i \le 10^5, m \le 10^5\)
  • \(N \le 10^7\)

(记\(0^0 = 1\)

思路

首先这个式子的组合意义就是把 \(N\) 个球分成 \(m\) 组,然后在每组中有序地选出 \(K_i\) 个元素(可以重复)。

如果 \(K_i = 1\),有个明显的组合意义就是选出一个代表元,然后两边各有一些元素,一个未知数变成两个未知数,总和减少了 \(1\)

当然 \(K_i > 1\) 也同理,可以看成选了 \(K_i\) 个代表元,但是要考虑代表元重复以及各个代表元间的顺序的问题。

\(\sum K_i = S\),枚举代表元的位置数 \(i \in [m, S]\),显然不同组(不在一个未知数内)的代表元相互独立,方案数相乘,可以用若干个 OGF 表示。

\(m\) 个代表元合并成 \(k\) 个(考虑顺序)的方案数用第二类 Stirling 数表示:

\[\begin{Bmatrix} m \\ k \end{Bmatrix} \cdot k! \]

第二类 Stirling 数的一行可以用容斥转化成卷积,见 Luogu 模板 第二类斯特林数·行,这里不讲了。

把这些 OGF 乘起来,就得到了 \(S\) 个代表元合并的方案数,剩下的用组合数就很容易数了。若有 \(i\) 个位置上有代表元,则方案数为 \(m + i\) 个非负未知数,和为 \(n - i\) 的不定方程解数计数。乘法的复杂度上界是 \(\mathcal O(n \log^2 n)\),具体这里使用类似石子合并的方法合并。

Code

多项式乘法使用 std::vector 存储,常数巨大,内存巨大

#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
#define File(s) freopen(s".in", "r", stdin), freopen(s".out", "w", stdout)
typedef long long ll;
namespace io {
	const int SIZE = (1 << 21) + 1;
	char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, c, qu[55]; int f, qr;
	#define gc() (iS == iT ? (iT = (iS = ibuf) + fread (ibuf, 1, SIZE, stdin), (iS == iT ? EOF : *iS ++)) : *iS ++)
	char getc () {return gc();}
	inline void flush () {fwrite (obuf, 1, oS - obuf, stdout); oS = obuf;}
	inline void putc (char x) {*oS ++ = x; if (oS == oT) flush ();}
	template <class I> inline void gi (I &x) {for (f = 1, c = gc(); c < '0' || c > '9'; c = gc()) if (c == '-') f = -1;for (x = 0; c <= '9' && c >= '0'; c = gc()) x = x * 10 + (c & 15); x *= f;}
	template <class I> inline void print (I x) {if (!x) putc ('0'); if (x < 0) putc ('-'), x = -x;while (x) qu[++ qr] = x % 10 + '0',  x /= 10;while (qr) putc (qu[qr --]);}
	struct Flusher_ {~Flusher_(){flush();}}io_flusher_;
}
using io :: gi; using io :: putc; using io :: print; using io :: getc;
template<class T> void upmax(T &x, T y){x = x>y ? x : y;}
template<class T> void upmin(T &x, T y){x = x<y ? x : y;}

const int p = 998244353, G = 3;
inline int add(int x, int y){return x+y>=p ? x+y-p : x+y;}
inline int sub(int x, int y){return x-y<0 ? x-y+p : x-y;}
inline int mul(int x, int y){return 1LL * x * y % p;}
inline int power(int x, int y){
	int res = 1;
	for(; y; y>>=1, x = mul(x, x)) if(y & 1) res = mul(res, x);
	return res;
}
inline int inv(int x){return power(x, p - 2);}

const int N = 10000005, M = 100005, Len = 262144;

int fac[N + M * 5], ifac[N + M * 5];
void preC(int n){
	fac[0] = 1;
	for(int i=1; i<=n; i++) fac[i] = mul(fac[i-1], i);
	ifac[n] = inv(fac[n]);
	for(int i=n-1; i>=0; i--) ifac[i] = mul(ifac[i+1], i+1);
}
inline int C(int n, int m){return mul(fac[n], mul(ifac[m], ifac[n - m]));}
inline int P(int n, int m){return mul(fac[n], ifac[n - m]);}
inline int equation(int m, int S){
	return C(S - 1 + m, m - 1);
}

int K[M];

int m, n;

namespace polynomial{
	int w[Len], invw[Len];
	struct _polyInit{
		_polyInit(){
			w[0] = invw[0] = 1;
			w[1] = power(G, (p - 1) / Len); invw[1] = inv(w[1]);
			for(int i=2; i<Len; i++){
				w[i] = mul(w[i-1], w[1]);
				invw[i] = mul(invw[i - 1], invw[1]);
			}
		}
	}_init;
	int last = -1;
	int rev[Len];
	void pre(int n){
		if(last == n) return ;
		last = n;
		int lg = -1, nn = n;
		while(nn != 1) nn >>= 1, ++lg;
		for(int i=1; i<n; i++)
			rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg);
	}
	void NTT(vector<int> &f, int *w){
		int n = f.size();
		pre(n);
		for(int i=1; i<n; i++)
			if(rev[i] < i) swap(f[i], f[rev[i]]);
		for(int l=1; l<n; l<<=1){
			int step = Len / (l << 1);
			for(int j=0; j<n; j+=(l<<1))
				for(int k=0, p=0; k<l; k++, p+=step){
					int x = f[j + k], y = mul(f[j + l + k], w[p]);
					f[j + k] = add(x, y); f[j + k + l] = sub(x, y);
				}
		}
	}
};
typedef vector<int> poly;
void operator*=(poly &a, poly &b){
	using namespace polynomial;
	int la = a.size(), lb = b.size(), len = 1;
	while(len < la + lb) len <<= 1;
	a.resize(len); b.resize(len);
	NTT(a, w); NTT(b, w);
	for(int i=0; i<len; i++) a[i] = mul(a[i], b[i]);
	NTT(a, invw);
	int invlen = inv(len);
	for(int i=0; i<len; i++) a[i] = mul(a[i], invlen);
	a.resize(la + lb - 1);
}

poly F[N];

struct LenCmp{
	bool operator() (int a, int b) const {return F[a].size() > F[b].size();}
};
priority_queue<int, vector<int>, LenCmp> q;

void getStirling(poly &a, int n){
	poly b(n + 1);
	a.resize(n + 1);
	for(int i=1; i<=n; i++){
		a[i] = mul(power(i, n), ifac[i]);
		b[i] = ifac[i];
		if(i & 1) b[i] = sub(0, b[i]);
	}
	b[0] = 1;
	a *= b;
	a.resize(n + 1);
}

int solve(){
	int kcnt = 0, ksum = 0;
	for(int i=1; i<=m; i++){
		kcnt += K[i] >= 1;
		ksum += K[i];
		if(K[i] == 0) continue;
		getStirling(F[i], K[i]);
		for(int j=1; j<=K[i]; j++) F[i][j] = mul(F[i][j], fac[j]);
		q.push(i);
	}
	if(q.size() == 0) return 1;
	while(q.size() != 1){
		int x = q.top(); q.pop();
		int y = q.top(); q.pop();
		F[x] *= F[y];
		q.push(x);
	}
	vector<int> &f = F[q.top()]; q.pop();
	int res = 0;
	for(int i=kcnt; i<=ksum; i++)
		res = add(res, mul(f[i], equation(m + i, n - i)));
	return res;
}

int main(){
	gi(m); gi(n);
	preC(n + m * 3);
	for(int i=1; i<=m; i++)
		gi(K[i]);
	printf("%d\n", solve());
	return 0;
}
posted @ 2020-03-14 20:42  RiverHamster  阅读(201)  评论(0编辑  收藏  举报
\