【洛谷4705】玩游戏(多项式)

s*s 怎么整天给我推荐神奇的多项式题,我又不会 ……

题目

洛谷 4705

分析

我感觉这题不是很好理解,我狠狠恶补了一波多项式知识。

当然主要原因还是我除了裸的多项式乘法以外压根没做过什么多项式题,乘法以上的操作(求逆之类)都只写过板子 ……

前置知识(这题要用 ln ):

(说实话我现在看我将近一年四个月以前写的这些东西,感觉写得并不好,很多地方没说清楚,什么时候有空了重新写一下。)

来看这道题。根据题意可以直接列出答案的表达式:

\[ans_k=\frac{\sum_{i=1}^n\sum_{j=1}^m (a_i+b_j)^k}{nm} \]

\(ans\) 理论上长度是无穷大,本题中只需要取前 \(t+1\) 项(相当于在模 \(x^{t+1}\) 意义下)。

把它写成生成函数的形式。生成函数仅仅是为了方便描述,把数列改写成多项式的形式,自变量 \(x\) 和求得的值都没有意义,我们要求的仅仅是每一项的系数。

\[ans=\sum_{k=0}^\infin\frac{\sum_{i=1}^n\sum_{j=1}^m (a_i+b_j)^k}{nm} \]

用二项式定理展开,然后做一些基本的变换。把一些可以直接算的常数放到左边:

\[\begin{aligned} ans&=\frac{1}{nm}\sum_{k=0}^\infin x_k\sum_{i=1}^n\sum_{j=1}^m \sum_{r=0}^k C_k^r\cdot a_i^r\cdot b_j^{k-r}\\ &=\frac{1}{nm}\sum_{k=0}^\infin x_k\sum_{r=0}^k \frac{k!}{r!\cdot (k-r)!}\cdot \sum_{i=1}^n a_i^r\cdot \sum_{j=1}^m b_j^{k-r}\\ &=\frac{1}{nm}\sum_{k=0}^\infin k!\cdot x_k\sum_{r=0}^k\sum_{i=1}^n\frac{a_i^r}{r!}\cdot\sum_{j=1}^m\frac{b_j^{k-r}}{(k-r)!}\\ \end{aligned}\]

观察发现:

\[ans=A*B\cdot C \]

其中星号表示多项式卷积,点号表示多项式点积(对应位相乘)(下同),多项式 \(A\)\(B\)\(C\) 的定义如下:

\[A=\sum_{r=0}^\infin\frac{\sum_{i=1}^{n}a_i^rx^r}{r!} \]

\[B=\sum_{r=0}^\infin\frac{\sum_{i=1}^{m}b_i^rx^r}{r!} \]

\[C=\frac{1}{nm}\sum_{r=0}^\infin r!x^r \]

这三个多项式的长度和 \(ans\) 一样都是无穷大。

注意到 \(C\) 可以直接计算,\(A\)\(B\) 形式相似。因此可以把问题转化成求形如下面这个的多项式,最后给每一项除以 \(r!\) 即可:

\[\begin{aligned} F&=\sum_{r=0}^\infin\sum_{i=1}^{n}a_i^rx^r\\ &=\sum_{i=1}^{n}\sum_{r=0}^\infin a_i^rx^r \end{aligned}\]

注意到第一层求和里面的内容是一个等比数列求和。因为可以给 \(x\) 指定任意一个非 \(0\) 值,所以不妨认为 \(0<x<1\) ,这样这个无穷等比数列的和就是收敛的,极限是 \(0\) 。根据等比数列求和公式,变为:

\[F=\sum_{i=1}^n \frac{1}{1-a_ix} \]

但这样还不够。别忘了要求的不是和而是模 \(x^{t+1}\) 意义下每一项的系数。上式可进一步变为:

\[\begin{aligned} F&=\sum_{i=1}^n \frac{1}{1-a_ix}\\ &=\sum_{i=1}^n (1+\frac{a_ix}{1-a_ix})\\ &=n-x\sum_{i=1}^n\frac{-a_i}{1-a_ix}\\ &=n-x\sum_{i=1}^n (\ln(1-a_ix))'\end{aligned}\]

注意 \((\ln(A(x)))'\) 是把 \(x\) 作为自变量求导,所以要用一次链式法则。即:

\[(\ln(A(x)))'=\frac{A'(x)}{A(x)} \]

别忘了求 \(\ln(1-a_ix)\) 要在模 \(x^{t+1}\) 的意义下进行。

(我到这里的时候以为做完了就直接把博客发上去了 …… 五分钟后突然想起来时间复杂度不太对)

因为导数的和等于和的导数,对数的和等于积的对数,所以:

\[\begin{aligned} F&=n-x\sum_{i=1}^n (\ln(1-a_ix))'\\ &=n-x(\sum_{i=1}^n \ln(1-a_ix))'\\ &=n-x(\ln(\prod_{i=1}^n (1-a_ix)))'\end{aligned}\]

求积可以用分治 + NTT 解决,时间复杂度 \(O(n\log^2 n)\) 。然后照着这个式子就能算出 \(F\)

这样我们就把 \(A\)\(B\) 算出来了。直接卷起来再点乘上 \(C\) 就是答案。

(等等这题好像并不是特别复杂,为什么我做的时候纠结得要死 ……)

代码

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cctype>
using namespace std;

namespace zyt
{
	template<typename T>
	inline bool read(T &x)
	{
		char c;
		bool f = false;
		x = 0;
		do
			c = getchar();
		while (c != EOF && c != '-' && !isdigit(c));
		if (c == EOF)
			return false;
		if (c == '-')
			f = true, c = getchar();
		do
			x = x * 10 + c - '0', c = getchar();
		while (isdigit(c));
		if (f)
			x = -x;
		return true;
	}
	template<typename T>
	inline void write(T x)
	{
		static char buf[20];
		char *pos = buf;
		if (x < 0)
			putchar('-'), x = -x;
		do
			*pos++ = x % 10 + '0';
		while (x /= 10);
		while (pos > buf)
			putchar(*--pos);
	}
	typedef long long ll;
	const int N = 1e5 + 10, P = 998244353;
	int fac[N], finv[N];
	int power(int a, int b)
	{
		int ans = 1;
		while (b)
		{
			if (b & 1)
				ans = (ll)ans * a % P;
			a = (ll)a * a % P;
			b >>= 1;
		}
		return ans;
	}
	int getinv(const int a)
	{
		return power(a, P - 2);
	}
	namespace Polynomial
	{
		const int LEN = N << 2, G = 3;
		int omega[LEN], winv[LEN], rev[LEN];
		void init(const int n, const int lg2)
		{
			int w = power(G, (P - 1) / n), wi = getinv(w);
			omega[0] = winv[0] = 1;
			for (int i = 1; i < n; i++)
			{
				omega[i] = (ll)omega[i - 1] * w % P;
				winv[i] = (ll)winv[i - 1] * wi % P;
			}
			for (int i = 0; i < n; i++)
				rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1)));
		}
		void ntt(int *const a, const int *const w, const int n)
		{
			for (int i = 0; i < n; i++)
				if (i < rev[i])
					swap(a[i], a[rev[i]]);
			for (int l = 1; l < n; l <<= 1)
				for (int i = 0; i < n; i += (l << 1))
					for (int k = 0; k < l; k++)
					{
						int x = a[i + k], y = (ll)a[i + k + l] * w[n / (l << 1) * k] % P;
						a[i + k] = (x + y) % P;
						a[i + k + l] = (x - y + P) % P;
					}
		}
		void mul(const int *const a, const int *const b, int *const c, const int n)
		{
			static int x[LEN], y[LEN];
			int m = 1, lg2 = 0;
			while (m < n + n - 1)
				m <<= 1, ++lg2;
			memcpy(x, a, sizeof(int[n]));
			memset(x + n, 0, sizeof(int[m - n]));
			memcpy(y, b, sizeof(int[n]));
			memset(y + n, 0, sizeof(int[m - n]));
			init(m, lg2);
			ntt(x, omega, m), ntt(y, omega, m);
			for (int i = 0; i < m; i++)
				x[i] = (ll)x[i] * y[i] % P;
			ntt(x, winv, m);
			int invm = getinv(m);
			for (int i = 0; i < n; i++)
				c[i] = (ll)x[i] * invm % P;
		}
		void _inv(const int *const a, int *const b, const int n)
		{
			if (n == 1)
			{
				b[0] = 1;
				return;
			}
			static int x[LEN], y[LEN];
			_inv(a, y, (n + 1) / 2);
			int m = 1, lg2 = 0;
			while (m <= n * 2)
				m <<= 1, ++lg2;
			init(m, lg2);
			memcpy(x, a, sizeof(int[n]));
			memset(x + n, 0, sizeof(int[m - n]));
			memset(y + (n + 1) / 2, 0, sizeof(int[m - (n + 1) / 2]));
			ntt(x, omega, m), ntt(y, omega, m);
			for (int i = 0; i < m; i++)
				y[i] = ((ll)y[i] * 2LL % P - (ll)x[i] * y[i] % P * y[i] % P + P) % P;
			ntt(y, winv, m);
			int invm = getinv(m);
			for (int i = 0; i < n; i++)
				b[i] = (ll)y[i] * invm % P;
		}
		void inv(const int *const a, int *b, const int n)
		{
			static int x[LEN];
			memcpy(x, a, sizeof(int[n]));
			_inv(x, b, n);
		}
		void derivative(const int *const a, int *const b, const int n)
		{
			for (int i = 1; i < n; i++)
				b[i - 1] = (ll)a[i] * i % P;
			b[n - 1] = 0;
		}
		void integral(const int *const a, int *const b, const int n)
		{
			for (int i = n - 1; i >= 0; i--)
				b[i + 1] = (ll)a[i] * getinv(i + 1) % P;
			b[0] = 0;
		}
		void ln(const int *const a, int *const b, const int n)
		{
			static int x[LEN], y[LEN];
			derivative(a, x, n);
			inv(a, y, n - 1);
			mul(x, y, b, n - 1);
			integral(b, b, n - 1);
		}
	}
	int a[N], b[N], n, m, A[N], B[N], t;
	void solve(const int *const a, const int l, const int r, int *const f)
	{
		using namespace Polynomial;
		if (l == r)
		{
			f[0] = 1, f[1] = P - a[l];
			return;
		}
		int mid = (l + r) >> 1, *x = new int[r - l + 2], *y = new int[r - l + 2];
		solve(a, l, mid, x), solve(a, mid + 1, r, y);
		memset(x + mid - l + 2, 0, sizeof(int[r - mid]));
		memset(y + r - mid + 1, 0, sizeof(int[mid - l + 1]));
		mul(x, y, f, r - l + 2);
		delete x;
		delete y;
	}
	void cal(const int *const a, const int n, int *const f)
	{
		using namespace Polynomial;
		solve(a, 1, n, f);
		if (t >= n)
			memset(f + n + 1, 0, sizeof(int[t - n + 1]));
		ln(f, f, t + 2);
		derivative(f, f, t + 1);
		for (int i = t; i > 0; i--)
			f[i] = P - f[i - 1];
		f[0] = n;
		for (int i = 0; i <= t; i++)
			f[i] = (ll)f[i] * finv[i] % P;
	}
	void init()
	{
		fac[0] = 1;
		for (int i = 1; i < N; i++)
			fac[i] = (ll)fac[i - 1] * i % P;
		finv[N - 1] = getinv(fac[N - 1]);
		for (int i = N - 1; i > 0; i--)
			finv[i - 1] = (ll)finv[i] * i % P;
	}
	int work()
	{
		using namespace Polynomial;
		init();
		read(n), read(m);
		for (int i = 1; i <= n; i++)
			read(a[i]);
		for (int i = 1; i <= m; i++)
			read(b[i]);
		read(t);
		cal(a, n, A), cal(b, m, B);
		mul(A, B, A, t + 1);
		for (int i = 1; i <= t; i++)
			write((ll)A[i] * fac[i] % P * getinv(n) % P * getinv(m) % P), putchar('\n');
		return 0;
	}
}
int main()
{
	freopen("4705.in", "r", stdin);
	return zyt::work();
}
posted @ 2020-04-24 09:46  Inspector_Javert  阅读(145)  评论(0编辑  收藏  举报