JZOJ 6757 2020.07.21【NOI2020】模拟T3 (至少容斥+OGF+NTT)

题目大意:

一个序列\(a[1..n](1 \le n \le N)\),满足:
\(a[i] \in [1,m]\)
\(a[i]<a[i+1]\)的对数\(=k\)

求方案数。

\(N,(N-k+1)*m \le 2^{20}\)

题解:

考虑容斥,恰好\(i\)个小于,变为至少\(i\)个小于,其它地方任意,容斥系数由二项式反演不难的是\(\binom{i}{k}*(-1)^{i-k}\)

这样我们想想,发现是\(n\)个数划分成\(n-i\)段,每一段里都是全部\(<\)的限制。

那么系数是:\(((1+x)^m-1)^{n-i}[x^n]\)

那么:
\(Ans=\sum_{i=k}^{n-1} \binom{i}{k}*(-1)^{i-k} \sum_{n=1}^{N} ((1+x)^m-1)^{n-i}[x^n]\)
最奇妙的一步:
\(=\sum_{i=k}^{n-1} \binom{i}{k}*(-1)^{i-k} \sum_{n=1}^{N} ((1+x)^m-1)^{n-i}*x^{N-n}[x^N]\)

\(f=((1+x)^m-1)\)
化成:
\(=\sum_{i=k}^{n-1} \binom{i}{k}*(-1)^{i-k} \sum_{n=1}^{N} f^{n-i}*x^{N-n}[x^N]\)

发现是个等比数列求和:
\(=\sum_{i=k}^{n-1} \binom{i}{k}*(-1)^{i-k} *\frac{f^{N+1-i}-x^{N+1-i}}{f-x}~[x^N]\)

注意特判\(m=1\)

分母、分子的第\(0\)项都是\(0\),提前除掉,不然没法求逆,分子也要多算一项。

分母可以求逆后分子和分母就可以分开了,分子\(x^{N+1}-i\)易得,考虑\(f^{N+1-i}\)

\(\sum_{i=k}^{n-1} \binom{i}{k}*(-1)^{i-k} * f^{N+1-i}\)
\(=\sum_{i=k}^{n-1} \binom{i}{k}*(-1)^{i-k} * ((1+x)^m - 1)^{N+1-i}\)
\(=\sum_{i=k}^{n-1} \binom{i}{k}*(-1)^{i-k} * \sum_{j=0}^{N+1-i} \binom{N+1-i}{j} (-1)^{N+1-i-j} (1+x)^{mj} [x^{0..N+1}]\)

这里已经可以直接NTT,就是把系数先全部给\(j\),再给\(x^{0..N+1}\)即可。

Code:

#include<bits/stdc++.h>
#define fo(i, x, y) for(int i = x, _b = y; i <= _b; i ++)
#define ff(i ,x, y) for(int i = x, _b = y; i <  _b; i ++)
#define fd(i, x, y) for(int i = x, _b = y; i >= _b; i --)
#define ll long long
#define pp printf
#define hh pp("\n")
using namespace std;

const int mo = 998244353;

ll ksm(ll x, ll y) {
	ll s = 1;
	for(; y; y /= 2, x = x * x % mo)
		if(y & 1) s = s * x % mo;
	return s;
}

#define V vector<ll>
#define si size()
#define re resize

namespace ntt {
	const int nm = 1 << 22;
	int r[nm]; ll w[nm];
	void build() {
		for(int i = 1; i < nm; i *= 2) {
			w[i] = 1; ll v = ksm(3, (mo - 1) / 2 / i);
			ff(j, 1, i) w[i + j] = w[i + j - 1] * v % mo;
 		}
	}
	void dft(ll *a, int n, int f) {
		ff(i, 0, n)	{
			r[i] = r[i / 2] / 2 + (i & 1) * (n / 2);
			if(i < r[i]) swap(a[i], a[r[i]]);
		} ll v;
		for(int i = 1; i < n; i *= 2) for(int j = 0; j < n; j += 2 * i) ff(k, 0, i)
			v = a[i + j + k] * w[i + k], a[i + j + k] = (a[j + k] - v) % mo, a[j + k] = (a[j + k] + v) % mo;
		if(f == -1) {
			reverse(a + 1, a + n);
			v = ksm(n, mo - 2);
			ff(i, 0, n) a[i] = (a[i] + mo) * v % mo;
		}
	}
	ll a[nm], b[nm];
	V operator * (V p, V q) {
		int n0 = p.si + q.si - 1, n = 1;
		while(n < n0) n *= 2;
		ff(i, 0, n) a[i] = b[i] = 0;
		ff(i, 0, p.si) a[i] = p[i];
		ff(i, 0, q.si) b[i] = q[i];
		dft(a, n, 1); dft(b, n, 1);
		ff(i, 0, n) a[i] = a[i] * b[i] % mo;
		dft(a, n, -1);
		p.re(n0);
		ff(i, 0, n0) p[i] = a[i];
		return p;
	}
	void dft(V &p, int f) {
		int n = p.si;
		ff(i, 0, n) a[i] = p[i];
		dft(a, n, f);
		ff(i, 0, n) p[i] = a[i];
	}
}

using ntt :: operator * ;
using ntt :: dft;

V operator + (V a, V b) {
	a.re(max(a.si, b.si));
	ff(i, 0, b.si) a[i] = (a[i] + b[i]) % mo;
	return a;
}

V qni(V a) {
	int n0 = a.si, n = 2;
	V b; b.re(1); b[0] = ksm(a[0], mo - 2);
	for(; n < n0 * 2; n *= 2) {
		V c = a; c.re(n); c.re(2 * n);
		b.re(2 * n);
		dft(c, 1); dft(b, 1);
		ff(i, 0, 2 * n) b[i] = (2 * b[i] - c[i] * b[i] % mo * b[i]) % mo;
		dft(b, -1);
		b.re(n);
	}
	b.re(n0);
	return b;
}

const int N = (1 << 21) + 5;

int n, m, k;
ll fac[N], nf[N];

void build(int n) {
	fac[0] = 1; fo(i, 1, n) fac[i] = fac[i - 1] * i % mo;
	nf[n] = ksm(fac[n], mo - 2); fd(i, n, 1) nf[i - 1] = nf[i] * i % mo;
}

ll C(int n, int m) {
	if(n < m) return 0;
	return fac[n] * nf[m] % mo * nf[n - m] % mo;
}

int main() {
	freopen("singer.in", "r", stdin);
	freopen("singer.out", "w", stdout);
	ntt :: build();
	build(1 << 21);
	scanf("%d %d %d", &n, &m, &k); 
	
	if(m == 1) {
		pp("%d\n", k == 0);
		return 0;
	}
	
	if(k == 0) {
		ll ans = (C(n + m, m) - 1 + mo) % mo;
		pp("%lld\n", ans);
		return 0;
	}
	
	V f; f.re(m);
	fo(i, 1, m) f[i - 1] = C(m, i);
	f[0] --;
	f.re(n + 1);
	f = qni(f);
	
	V s; s.re(n + 2);
	
	V a; a.re(n + 1);
	fo(i, k, n) {
		a[n + 1 - i] = C(i, k) * ((i - k) % 2 ? -1 : 1) * fac[n + 1 - i] % mo;
	}
	V b; b.re(n + 1);
	fo(i, 0, n) b[i] = (i % 2 ? -1 : 1) * nf[i] % mo;
	reverse(b.begin(), b.end());
	
	a = a * b;
	
	int t = m * (n - k + 1);
	
	V p; p.re(t + 1);
	fo(j, 0, n + 1 - k) {
		ll xs = a[j + n] * nf[j] % mo;
		p[j * m] = fac[j * m] * xs % mo;
	}
	V q; q.re(t + 1);
	fo(i, 0, t) q[i] = nf[i];
	reverse(q.begin(), q.end());
	
	p = p * q;
	fo(k, 0, t) {
		ll xs = p[k + t] * nf[k] % mo;
		if(k > n + 1) break;
		s[k] = (s[k] + xs) % mo;
	}
	
//	s.clear(); s.re(n + 2);
//	fo(i, k, n) {
//		ll xs = C(i, k) * ((i - k) % 2 ? -1 : 1);
//		pp("%d %lld\n", i, xs);
//		fo(j, 0, n + 1 - i) {
//			ll xs2 = xs * C(n + 1 - i, j) * ((n + 1 - i - j) % 2 ? -1 : 1) % mo;
//			fo(k, 0, n + 1) s[k] = (s[k] + xs2 * C(j * m, k)) % mo;
//		}
//	}
//	pp("s =\n");
//	ff(i, 0, s.si)	pp("%lld ", s[i]); hh;
	fo(i, k, n) s[n + 1 - i] = (s[n + 1 - i] - C(i, k) * ((i - k) % 2 ? -1 : 1)) % mo;
	
	ff(i, 1, s.si) s[i - 1] = s[i];
	s.re(s.si - 1);
	
	s = s * f;
	
	
	ll ans = s[n];
	ans = (ans % mo + mo) % mo;
	
	pp("%lld\n", ans);
}
posted @ 2020-07-21 21:09  Cold_Chair  阅读(403)  评论(0编辑  收藏  举报