LOJ#6703. 小 Q 的序列

题面

题解

\(f_{i, j}\) 为前 \(i\) 个数选了 \(j\) 个的权值和,那么有:

\[f_{i, j} = f_{i - 1, j} + (a_i + j) f_{i - 1, j - 1} \]

\(F_i(x) = \sum_j f_{i, j} x^j\),于是可以得出 \(F_i = (1 + x + a_ix) F_{i - 1} + x^2 F'_{i-1}\)

然后发现方程里面出现了 \(x^2 F'_{i-1}\) 这种极度不友好的项,虽然能做但是有亿点点麻烦。

考虑将 \(f_{i, j}\) 的第二维反转,即改令 \(f_{i, j}\) 表示前 \(i\) 个数中选了 \(i - j\) 个的权值和(也就是原来的 \(f_{i, i - j}\)),这样子就有:

\[f_{i, j} = f_{i - 1, j - 1} + (a_i + i - j) f_{i - 1, j} \]

\(b_i = a_i + i\),那么 dp 用生成函数的形式表示就是

\[F_i = xF_{i - 1} + b_i F_{i - 1} - xF'_{i-1} \]

观察到,如果没有 \(xF_{i - 1}\) 这一项,那么转移是非常 trivial 的,所以说考虑将 \(F_i\) 的转移方程向这个方向靠拢。

\(H_i = F_iG\),那么我们希望能够有:

\[H_i = b_i H_{i-1} - xH'_{i-1} \]

展开得

\[F_iG = b_iF_{i-1}G - x(F'_{i-1}G+F_{i-1}G') \]

而由原来的方程可知

\[F_iG = xF_{i - 1}G + b_i F_{i - 1}G - xF'_{i-1}G \]

那么要满足 \(H_i\) 的转移方程成立的条件,只需要 \(xF_{i-1}G = -xF_{i-1}G'\) 即可,即 \(G(x) = e^{-x}\)

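这样,由于 \(F_0 = 1\),有 \(H_0 = F_0 G = e^{-x}\);只要求出了 \(H_n\),就有 \(F_n = e^x H_n\)。最终答案为 \(\sum_{j=0}^{n-1} [x^j] F_n\):第二维反转后,\(x^n\) 的系数对应一个数都不选的情况,不计入答案。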

\(H_i\) 的生成函数递推式展开就可以知道 \(h_{i, j} = (b_i - j) h_{i - 1, j}\),就有 \(h_{0, j} = \frac {(-1)^j} {j!}\)\(h_{n, j} = h_{0, j} \prod_{1 \leq i \leq n} (b_i - j)\),多点求值即可。

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)

// Reads one (optionally negative) decimal integer from stdin.
// Any characters before the first digit or '-' are skipped.
// Reading stops at the first non-digit after the number.
inline int read()
{
	char ch = getchar();
	// Skip anything that cannot start an integer literal.
	while (ch != '-' && !(ch >= '0' && ch <= '9')) ch = getchar();
	const int sign = (ch == '-') ? -1 : 1;
	if (sign < 0) ch = getchar();
	int value = 0;
	for (; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch - '0');
	return sign * value;
}

const int N(3e5 + 10), Mod(998244353);

// Brings x from [-Mod, Mod) into [0, Mod) by adding Mod to negative values.
inline int upd(const int &x) { return x < 0 ? x + Mod : x; }

// Computes x^y mod Mod by binary exponentiation.
// y is expected to be non-negative.
int fastpow(int x, int y)
{
	long long base = x, result = 1;
	while (y)
	{
		if (y & 1) result = result * base % Mod;
		base = base * base % Mod;
		y >>= 1;
	}
	return static_cast<int>(result);
}

// NTT-based polynomial arithmetic modulo Mod = 998244353 = 119 * 2^23 + 1,
// whose primitive root is 3.
// DFT/IDFT operate on the global length L chosen by the most recent Init call.
// NOTE(review): the root table w and every scratch buffer below hold N ints,
// so each transform length chosen by Init must satisfy L <= N.
namespace Poly
{
	// w[i + j] (i a power of two, 0 <= j < i) is the j-th power of a primitive
	// (2i)-th root of unity.
	// L is the current transform length and invL = L^{-1} mod Mod.
	// Len is the first level of w that has not been filled yet.
	int w[N], L, invL, Len = 1;

	// Sets L to the smallest power of two >= n and recomputes invL.
	// Extends the root table up to L; levels already filled are kept.
	void Init(int n)
	{
		for (L = 1; L < n; L <<= 1) {}
		invL = fastpow(L, Mod - 2);
		for (int &i = Len, t; i < L; i <<= 1)
		{
			// Mod / (2i) == (Mod - 1) / (2i) whenever 2i divides Mod - 1 (2i <= 2^23).
			w[i] = 1, t = fastpow(3, Mod / (i << 1));
			for (int j = 1; j < i; j++) w[i + j] = 1ll * w[i + j - 1] * t % Mod;
		}
	}

	// In-place forward transform of p[0, L) (decimation in frequency).
	// The output is left in bit-reversed order, which is exactly the order
	// IDFT consumes, so no bit-reversal permutation is needed anywhere.
	void DFT(int *p)
	{
		for (int i = L >> 1, s = L; i; i >>= 1, s >>= 1)
			for (int j = 0; j < L; j += s) for (int k = 0, o = i; k < i; ++k, ++o)
			{
				int x = p[j + k], y = p[i + j + k];
				p[j + k] = upd(x + y - Mod), p[i + j + k] = 1ll * w[o] * upd(x - y) % Mod;
			}
	}

	// In-place inverse transform of p[0, L) (decimation in time).
	// Takes DFT's bit-reversed layout back to natural-order coefficients.
	// Reversing p[1, L) turns the transform with root w into the one with
	// root w^{-1}; the final loop divides by L.
	void IDFT(int *p)
	{
		for (int i = 1, s = 2; i < L; i <<= 1, s <<= 1)
			for (int j = 0; j < L; j += s) for (int k = 0, o = i; k < i; ++k, ++o)
			{
				int x = p[j + k], y = 1ll * w[o] * p[i + j + k] % Mod;
				p[j + k] = upd(x + y - Mod), p[i + j + k] = upd(x - y);
			}
		std::reverse(p + 1, p + L);
		for (int i = 0; i < L; i++) p[i] = 1ll * p[i] * invL % Mod;
	}

	// Power-series inverse: writes b[0, n) so that a * b == 1 (mod x^n).
	// Requires a[0] != 0 (mod Mod) and a[0, n) readable.
	// Uses the Newton step b <- 2b - a b^2, starting from the first
	// h = ceil(n / 2) coefficients, which are computed recursively.
	// b[h, n) may hold arbitrary values on entry; they are ignored.
	void Inv(int *a, int *b, int n)
	{
		if (n == 1) return (void) (*b = fastpow(*a, Mod - 2));
		static int c[N], d[N];
		const int h = (n + 1) >> 1;
		Inv(a, b, h);
		// A length of at least 1.5n is enough. a * b^2 has degree <= 2n - 2, so
		// cyclic wrap-around only lands on indices < h, which are never read below.
		// (3n + 1) / 2 equals the former int(n * 1.5 + 0.5) without floating point.
		Init((3 * n + 1) >> 1);
		std::memset(c, 0, L << 2), std::memcpy(c, a, n << 2), DFT(c);
		// Copy only the h coefficients computed so far. Previously n entries were
		// copied, so any garbage in b[h, n) leaked into the product.
		std::memset(d, 0, L << 2), std::memcpy(d, b, h << 2), DFT(d);
		for (int i = 0; i < L; i++) c[i] = 1ll * c[i] * d[i] % Mod * d[i] % Mod;
		IDFT(c);
		// For i >= h, 2b contributes nothing, so the new coefficient is -(a b^2)_i.
		for (int i = h; i < n; i++) b[i] = upd(-c[i]);
	}

	// c[0, n + m - 1) = a[0, n) * b[0, m).
	// The inputs are copied into scratch buffers first, so c may alias a or b.
	void Mul(const int *a, const int *b, int *c, int n, int m)
	{
		static int f[N], g[N];
		Init(n + m - 1);
		std::memset(f, 0, L << 2), std::memcpy(f, a, n << 2), DFT(f);
		std::memset(g, 0, L << 2), std::memcpy(g, b, m << 2), DFT(g);
		for (int i = 0; i < L; i++) f[i] = 1ll * f[i] * g[i] % Mod;
		IDFT(f);
		std::memcpy(c, f, (n + m - 1) << 2);
	}

	// Transposed multiplication: c[i] = sum_{0 <= j < m} a[i + j] * b[j]
	// for 0 <= i < k, with a[q] = 0 for q >= n.
	// It is computed as a * reverse(b), reading coefficients m - 1 .. m + k - 2,
	// so it requires k <= n.
	void MulT(const int *a, const int *b, int *c, int n, int m, int k)
	{
		static int f[N], g[N];
		std::memcpy(f, a, n << 2), std::memcpy(g, b, m << 2);
		std::reverse(g, g + m), Mul(f, g, f, n, m), std::memcpy(c, f + m - 1, k << 2);
	}
}

// Vector front end for Poly::Mul: h = f * g.
// h must already be sized to f.size() + g.size() - 1 by the caller.
inline void Mul(const std::vector<int> &f, const std::vector<int> &g, std::vector<int> &h)
{
	const int lenF = static_cast<int>(f.size());
	const int lenG = static_cast<int>(g.size());
	Poly::Mul(f.data(), g.data(), h.data(), lenF, lenG);
}
// Vector front end for Poly::MulT.
// Fills every slot of h (its current size is the output length k).
inline void MulT(const std::vector<int> &f, const std::vector<int> &g, std::vector<int> &h)
{
	const int lenF = static_cast<int>(f.size());
	const int lenG = static_cast<int>(g.size());
	const int lenH = static_cast<int>(h.size());
	Poly::MulT(f.data(), g.data(), h.data(), lenF, lenG, lenH);
}

// Segment-tree-shaped product tree.
// v[x] holds the node polynomial; w[x] is the node's buffer for Div.
std::vector<int> v[N << 2], w[N << 2];

// Builds the product tree over a[l..r] rooted at node x.
// v[x] = prod_{i=l}^{r} (1 - a[i] X), lowest degree first.
// w[x] is sized r - l + 1, ready for Div.
void getPoly(const int *a, int x, int l, int r)
{
	if (l == r)
	{
		v[x] = {1, upd(-a[l])};
		w[x] = {0};
		return;
	}
	const int mid = (l + r) >> 1;
	const int lc = x << 1, rc = lc | 1;
	getPoly(a, lc, l, mid);
	getPoly(a, rc, mid + 1, r);
	v[x].resize(r - l + 2);
	w[x].resize(r - l + 1);
	Mul(v[lc], v[rc], v[x]);
}

// Pushes node x's buffer w[x] down the product tree.
// Each child receives MulT of the parent buffer with its sibling's polynomial.
// At leaf l the first entry of the leaf buffer is written to ans[l].
void Div(int *ans, int x, int l, int r)
{
	if (l != r)
	{
		const int mid = (l + r) >> 1;
		const int lc = x << 1, rc = lc | 1;
		MulT(w[x], v[lc], w[rc]);
		MulT(w[x], v[rc], w[lc]);
		Div(ans, lc, l, mid);
		Div(ans, rc, mid + 1, r);
		return;
	}
	ans[l] = w[x].front();
}

// Transposed multipoint evaluation.
// Takes the n coefficients f[0, n) and the n - 1 points a[1, n - 1], and
// writes f(a[i]) to ans[i] for 1 <= i <= n - 1.
// At least one point is required (n >= 2).
void Solve(const int *f, const int *a, int *ans, int n)
{
	static int g[N];
	const int points = n - 1;
	getPoly(a, 1, 1, points);                 // v[1] = prod (1 - a[i] X)
	Poly::Inv(v[1].data(), g, n);             // g = 1 / v[1] mod X^n
	Poly::MulT(f, g, w[1].data(), n, n, static_cast<int>(w[1].size()));
	Div(ans, 1, 1, points);
}

// Input size and global work arrays (filled in by main).
int n, a[N], b[N], ans[N], fac[N], inv[N];

// Returns the coefficients (lowest degree first) of prod_{i=l}^{r} (a[i] - X)
// mod Mod, built by divide and conquer. Expects l <= r.
std::vector<int> Prod(int l = 1, int r = n)
{
	if (l != r)
	{
		const int mid = (l + r) >> 1;
		std::vector<int> res(r - l + 2);
		Mul(Prod(l, mid), Prod(mid + 1, r), res);
		return res;
	}
	return {a[l], Mod - 1};   // a[l] - X, with -1 stored as Mod - 1
}

int main()
{
#ifndef ONLINE_JUDGE
	file(cpp);
#endif
	n = read();

	// Factorials and inverse factorials: fac[i] = i!, inv[i] = (i!)^{-1}.
	fac[0] = 1;
	for (int i = 1; i <= n; ++i)
		fac[i] = static_cast<int>(1ll * fac[i - 1] * i % Mod);
	inv[n] = fastpow(fac[n], Mod - 2);
	for (int i = n; i > 0; --i)
		inv[i - 1] = static_cast<int>(1ll * inv[i] * i % Mod);

	// a[i] <- input_i + i (mod Mod), and the evaluation point b[i] = i.
	// ans[0] = prod a[i] = P(0) for P(X) = prod (a[i] - X).
	ans[0] = 1;
	for (int i = 1; i <= n; ++i)
	{
		a[i] = upd(read() + i - Mod);
		b[i] = i;
		ans[0] = static_cast<int>(1ll * ans[0] * a[i] % Mod);
	}

	// ans[j] <- P(j) for j = 1..n by multipoint evaluation.
	const std::vector<int> poly = Prod();
	Solve(poly.data(), b, ans, n + 1);

	// ans[j] <- (-1)^j * P(j) / j!
	for (int j = 0; j <= n; ++j)
	{
		const int scaled = static_cast<int>(1ll * ans[j] * inv[j] % Mod);
		ans[j] = (j & 1) ? upd(-scaled) : scaled;
	}

	// Multiply by e^X (coefficients 1/j!), then sum the coefficients below X^n.
	std::vector<int> hn(ans, ans + n + 1), ex(inv, inv + n + 1);
	std::vector<int> fn(hn.size() + ex.size() - 1);
	Mul(hn, ex, fn);
	int res = 0;
	for (int j = 0; j < n; ++j) res = upd(res + fn[j] - Mod);
	printf("%d\n", res);
	return 0;
}
posted @ 2021-01-26 11:51  xgzc  阅读(207)  评论(0编辑  收藏  举报