[LOJ#3315]「ZJOI2020」抽卡

生成函数神题 QAQ,orz EI

我的多项式水平是外国人水平

题目链接

题目传送门

简要算法

概率与期望、容斥、生成函数、拉格朗日反演、牛顿迭代

\(O(m^2)\)

\(O(m^2)\) 做法有很多,如 min-max 容斥,下面介绍一种看上去比较有优化空间的做法。

对于 \(S\subseteq\{1,2,\dots,m\}\),定义 \(end_S=0/1\) 表示 \(S\) 是否存在 \(k\) 个连续的数。

考虑每一轮的贡献,第 \(x\) 轮的贡献就是前 \(x\) 轮操作之后不会到达终态的概率:

\[ans=\sum_{end_S=0}\sum_{x\ge 0}P(x轮之后选过的数组成的集合恰好为S) \]

对于 \(x\) 轮之后选过的数组成的集合恰好\(S\) 的概率,考虑容斥计算:

\[P(x轮之后选过的数组成的集合恰好为S)=\sum_{T\subseteq S}(-1)^{|S|-|T|}(\frac{\sum_{i\in T}is_i}m)^x \]

其中 \(is_i\) 表示是否存在编号为 \(i\) 的卡。

\[ans=\sum_{end_S=0}\sum_{T\subseteq S}(-1)^{|S|-|T|}\sum_{x\ge 0}(\frac{\sum_{i\in T}is_i}m)^x=\sum_{end_S=0}\sum_{T\subseteq S}(-1)^{|S|-|T|}\frac m{m-\sum_{i\in T}is_i} \]

\(w_i(x)=x^{is_i}-1\)\(G(x)=\sum_{end_S=0}\prod_{i\in S}w_i(x)\),则答案为 \(\sum_{i=0}^{m-1}\frac m{m-i}[x^i]G(x)\)

由于 \(is_i=0\)\(w_i(x)=0\)\(is_i=1\)\(w_i(x)=x-1\),故只要先 DP \(f_{i,j}\) 表示前 \(i\) 种编号选出 \(j\) 个均为 \(is=1\) 的方案数(转移可以容斥掉最后一段长为 \(k\) 的方案),则 \(G(x)=\sum_{i\ge 0}f_{\max,i}(x-1)^i\),可以直接计算。

Solution by EI

对于 \(G(x)=\sum_{i\ge 0}f_{\max,i}(x-1)^i\) 的每一项,注意到 \([x^i]G(x)=\sum_{j\ge i}(-1)^{j-i}\binom jif_{\max,j}\),可以一次卷积求出。

对于上面的 DP,实际上可以把输入的 \(a\) 数组排序之后分成一些值域连续段,求出每个连续段(\(is\) 全为 \(1\))中选出 \(0,1,\dots\) 个元素的方案数,最后用一次分治 NTT \(O(m\log^2m)\) 求出。

现在要解决的问题就是给定 \(n\),如何对于每个 \(i=0,1,\dots,n\) 计算出在 \(n\) 个元素中选出 \(i\) 个使得没有任意连续的 \(k\) 个元素被选出的方案数。

可以转化成对于每个 \(i=0,1,\dots,n\) 计算出把 \(n+1\) 拆分成 \(n+1-i\) 个不超过 \(k\) 的正整数之和的方案数。转化方法即为增加一个不能选的元素 \(n+1\),以所有不选的元素为右端点,把该元素左边有被选上的一段元素并起来作为一段。

也就是对于任意 \(1\le m\le n+1\) 求出:

\[[x^{n+1}](\sum_{i=1}^kx_i)^m \]

也就是求二元生成函数:

\[\frac1{1-u(\sum_{i=1}^kx_i)} \]

\(x^{n+1}\) 次项。

对于只能求某一项的问题我们通常考虑拉格朗日反演,设 \(G(x)=\sum_{i=1}^kx_i\)\(G(x)\) 的复合逆为 \(G^{-1}(x)\),我们有:

\[[x^{n+1}]\frac1{1-u(\sum_{i=1}^kx_i)}=\frac1{n+1}[x^n]\frac u{(1-ux)^2}(\frac x{G^{-1}(x)})^{n+1} \]

由于 \(\frac u{(1-ux)^2}\)\(u\) 的次数总是比 \(x\) 的次数多 \(1\),故如果求出了 \((\frac x{G^{-1}(x)})^{n+1}\),就能枚举 \(\frac u{(1-ux)^2}\)\(x\) 的次数计算这两个式子积的第 \(n\) 项了。

现在要求的就是 \(F(x)=G^{-1}(x)\)。由于 \(G(x)=\sum_{i=1}^kx_i=\frac{x-x^{k+1}}{1-x}\),故我们有:

\[\frac{F(x)-F^{k+1}(x)}{1-F(x)}=x \]

\[(1+x)F(x)-F^{k+1}(x)-x=0 \]

可以牛顿迭代。

总复杂度 \(O(m\log^2m)\),瓶颈在分治 NTT,但牛顿迭代部分的常数还不止一个 log

Code

#include <bits/stdc++.h>

template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}

const int N = 3e6 + 5, djq = 998244353;

int n, m, rev[N], yg[N], a[N], b[N], ff, tot, t1[N], t2[N], t3[N], t4[N], inv[N], f[N],
t5[N], t6[N], t7[N], cnt[N], len, fac[N], invf[N], ans;
std::vector<int> A[N];

int qpow(int a, int b)
{
	int res = 1;
	while (b)
	{
		if (b & 1) res = 1ll * res * a % djq;
		a = 1ll * a * a % djq;
		b >>= 1;
	}
	return res;
}

inline void add(int &a, const int &b) {if ((a += b) >= djq) a -= djq;}

inline void sub(int &a, const int &b) {if ((a -= b) < 0) a += djq;}

void FFT(int n, int *a, int op)
{
	for (int i = 0; i < n; i++) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
	yg[n] = qpow(1312005, (djq - 1) / n * ((n + op) % n));
	for (int i = n >> 1; i; i >>= 1)
		yg[i] = 1ll * yg[i << 1] * yg[i << 1] % djq;
	for (int k = 1; k < n; k <<= 1)
	{
		int x = yg[k << 1];
		for (int i = 0; i < n; i += k << 1)
		{
			int w = 1;
			for (int j = 0, *f1 = a + i, *f2 = a + i + k; j < k; j++, f1++, f2++)
			{
				int u = *f1, v = 1ll * w * (*f2) % djq;
				add(*f1 = u, v); sub(*f2 = u, v);
				w = 1ll * w * x % djq;
			}
		}
	}
	if (op == -1)
	{
		int gg = qpow(n, djq - 2);
		for (int i = 0; i < n; i++) a[i] = 1ll * a[i] * gg % djq;
	}
}

void nealchen(int n)
{
	ff = 1; tot = 0;
	while (ff < n) ff <<= 1, tot++;
	for (int i = 0; i < ff; i++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << tot - 1);
}

void getinv(int n, int *a, int *b)
{
	b[0] = 1;
	for (int k = 1; k <= n; k <<= 1)
	{
		nealchen(k << 2);
		for (int i = k; i < ff; i++) b[i] = 0;
		for (int i = 0; i < ff; i++) t1[i] = i <= n && i < (k << 1) ? a[i] : 0;
		FFT(ff, b, 1); FFT(ff, t1, 1);
		for (int i = 0; i < ff; i++) b[i] = (2ll - 1ll * t1[i] * b[i] % djq
			+ djq) * b[i] % djq;
		FFT(ff, b, -1);
	}
}

void getln(int n, int *a, int *b)
{
	getinv(n, a, t2); b[n] = 0; nealchen(n << 1 | 1);
	for (int i = 1; i <= n; i++) b[i - 1] = 1ll * i * a[i] % djq;
	for (int i = n + 1; i < ff; i++) t2[i] = b[i] = 0;
	FFT(ff, b, 1); FFT(ff, t2, 1);
	for (int i = 0; i < ff; i++) b[i] = 1ll * b[i] * t2[i] % djq;
	FFT(ff, b, -1);
	for (int i = n; i >= 1; i--) b[i] = 1ll * b[i - 1] * inv[i] % djq;
	b[0] = 0;
}

void getexp(int n, int *a, int *b)
{
	b[0] = 1;
	for (int k = 1; k <= n; k <<= 1)
	{
		for (int i = k; i < (k << 2); i++) b[i] = 0;
		getln((k << 1) - 1, b, t3); nealchen(k << 2);
		for (int i = 0; i < ff; i++)
		{
			if (i >= (k << 1)) {t2[i] = 0; continue;}
			t2[i] = i <= n ? a[i] : 0; sub(t2[i], t3[i]); if (!i) add(t2[i], 1);
		}
		FFT(ff, b, 1); FFT(ff, t2, 1);
		for (int i = 0; i < ff; i++) b[i] = 1ll * b[i] * t2[i] % djq;
		FFT(ff, b, -1);
	}
}

void getpow(int n, int k, int *a, int *b)
{
	getln(n, a, t4);
	for (int i = 0; i <= n; i++) t4[i] = 1ll * t4[i] * k % djq;
	getexp(n, t4, b);
}

void calc(int n, int *a)
{
	for (int i = 0; i <= n; i++) t5[i] = f[i];
	getpow(n, n + 1, t5, t6);
	for (int i = 0; i <= n; i++)
		a[n - i] = 1ll * inv[n + 1] * (i + 1) % djq * t6[n - i] % djq;
}

std::vector<int> polymul(std::vector<int> a, std::vector<int> b)
{
	int n = a.size(), m = b.size(); nealchen(n + m - 1);
	for (int i = 0; i < ff; i++) t1[i] = i < n ? a[i] : 0, t2[i] = i < m ? b[i] : 0;
	FFT(ff, t1, 1); FFT(ff, t2, 1);
	for (int i = 0; i < ff; i++) t1[i] = 1ll * t1[i] * t2[i] % djq;
	FFT(ff, t1, -1); std::vector<int> res;
	for (int i = 0; i < n + m - 1; i++) res.push_back(t1[i]);
	return res;
}

std::vector<int> nealchen2003(int l, int r)
{
	if (l == r) return A[l];
	int mid = l + r >> 1;
	return polymul(nealchen2003(l, mid), nealchen2003(mid + 1, r));
}

int main()
{
	read(n); read(m); inv[1] = f[0] = fac[0] = invf[0] = 1;
	for (int i = 2; i <= n + 1; i++)
		inv[i] = 1ll * (djq - djq / i) * inv[djq % i] % djq;
	for (int k = 1; k <= n; k <<= 1)
	{
		getpow((k << 1) - 1, m, f, t5); nealchen(k << 2);
		for (int i = k << 1; i < ff; i++) t5[i] = 0;
		for (int i = 0; i < ff; i++) t6[i] = f[i], t7[i] = t5[i];
		FFT(ff, t6, 1); FFT(ff, t7, 1);
		for (int i = 0; i < ff; i++) t6[i] = 1ll * t6[i] * t7[i] % djq;
		FFT(ff, t6, -1);
		for (int i = k << 1; i < ff; i++) t6[i] = 0;
		for (int i = (k << 1) - 1; i >= 0; i--)
			t6[i] = i >= m ? (djq - t6[i - m]) % djq : 0,
			t5[i] = i >= m ? (1ll * djq * djq - 1ll * (m + 1) * t5[i - m]) % djq : 0;
		add(t5[1], 1); add(t5[0], 1); sub(t6[0], 1);
		for (int i = 0; i < k; i++) add(t6[i + 1], f[i]), add(t6[i], f[i]);
		getinv((k << 1) - 1, t5, t7); nealchen(k << 2);
		for (int i = k << 1; i < ff; i++) t6[i] = t7[i] = 0;
		FFT(ff, t6, 1); FFT(ff, t7, 1);
		for (int i = 0; i < ff; i++) t6[i] = 1ll * t6[i] * t7[i] % djq;
		FFT(ff, t6, -1);
		for (int i = 0; i < (k << 1); i++) sub(f[i], t6[i]);
	}
	getinv(n, f, t5); for (int i = 0; i <= n; i++) f[i] = t5[i];
	for (int i = 1; i <= n; i++) read(a[i]); std::sort(a + 1, a + n + 1);
	for (int i = 1; i <= n; i++)
	{
		if (i == 1 || a[i] > a[i - 1] + 1) len++;
		cnt[len]++;
	}
	for (int i = 1; i <= len; i++)
	{
		calc(cnt[i], t7);
		for (int j = 0; j <= cnt[i]; j++) A[i].push_back(t7[j]);
	}
	std::vector<int> nc = nealchen2003(1, len);
	for (int i = 0; i <= n; i++) t1[i] = nc[i];
	for (int i = 1; i <= n; i++) fac[i] = 1ll * fac[i - 1] * i % djq,
		invf[i] = 1ll * invf[i - 1] * inv[i] % djq;
	for (int i = 0; i <= n; i++)
	{
		t1[i] = 1ll * t1[i] * fac[i] % djq;
		if (t2[n - i] = invf[i], i & 1) t2[n - i] = djq - t2[n - i];
	}
	nealchen(n << 1 | 1);
	for (int i = n + 1; i < ff; i++) t1[i] = t2[i] = 0;
	FFT(ff, t1, 1); FFT(ff, t2, 1);
	for (int i = 0; i < ff; i++) t1[i] = 1ll * t1[i] * t2[i] % djq;
	FFT(ff, t1, -1);
	for (int i = 0; i < n; i++) ans = (1ll * invf[i] * t1[n + i]
		% djq * n % djq * inv[n - i] + ans) % djq;
	return std::cout << ans << std::endl, 0;
}
posted @ 2020-07-13 21:23  epic01  阅读(460)  评论(0编辑  收藏  举报