【LOJ #2320】「清华集训 2017」生成树计数

Description

题目链接:

在一个 \(s\) 个点的图中,存在 \(s-n\) 条边,使图中形成了 \(n\) 个连通块,第 \(i\) 个连通块中有 \(a_i\) 个点。

现在我们需要再连接 \(n-1\) 条边,使该图变成一棵树。对一种连边方案,设原图中第 \(i\) 个连通块连出了 \(d_i\) 条边,那么这棵树 \(T\) 的价值为:

\[\mathrm{val}(T) = \left(\prod_{i=1}^{n} {d_i}^m\right)\left(\sum_{i=1}^{n} {d_i}^m\right) \]

你的任务是求出所有可能的生成树的价值之和,对 \(998244353\) 取模。

\(n \leq 3\times 10^4,m \leq 30\)

时空限制:\(\texttt{5s/1GB}\)

Solution

算法一

由于我比较菜,所以想了半天才会这个暴力。

将每个连通块看成一个点,首先我们知道 Prufer 序列中每个点的出现次数就是度数减一,因此我们不妨考虑枚举度数序列计算。

考虑在两个大小分别为 \(a\)\(b\) 的连通块之间连边有 \(a\cdot b\) 种选择,因此我们把所有边的贡献相乘,所以每种连通块的生成树对应的原树的方案数为 \(\prod_{i=1}^na_i^{d_i}\)

\(q_i\) 表示 Prufer 序列中 \(i\) 的出现次数,即 \(q_i=d_i-1\)。如果确定了一个 \(\sum q_i=n-2\),那么我们有

\[\text{ans}=\sum_{\sum q_i=n-2}\frac{(n-2)!}{\prod_{i=1}^nq_i!}\prod_{i=1}^na_i^{q_i+1}\left(\prod_{i=1}^{n} {(q_i+1)}^m\right)\left(\sum_{i=1}^{n} {(q_i+1)}^m\right) \]

这个式子只需要 \(q_i\) 的信息即可计算,我们仔细观察可以发现这个式子是可以 DP 的。

首先我们将奇怪的项先提出来,得到

\[\text{ans}=(n-2)!\sum_{\sum q_i=n-2}\prod_{i=1}^n\frac{a_i^{q_i+1}}{q_i!}\left(\prod_{i=1}^{n} {(q_i+1)}^m\right)\left(\sum_{i=1}^{n} {(q_i+1)}^m\right) \]

考虑当前考虑到前 \(n\) 个点有 \(\sum_{i=1}^nq_i=s\),需要考虑的式子是下面这样的,不妨设它为 \(g(n,s)\)

\[g(n,s)=\sum_{\sum_{i=1}^nq_i=s}\prod_{i=1}^n\frac{a_i^{q_i+1}}{q_i!}\left(\prod_{i=1}^{n} {(q_i+1)}^m\right)\left(\sum_{i=1}^{n} {(q_i+1)}^m\right) \]

那么考虑新加入一个 \(q_{n+1}=k\),这个式子就变为

\[\frac{(k+1)^m\cdot a_{n+1}^{k+1}}{k!}\sum_{\sum_{i=1}^nq_i=s}\prod_{i=1}^n\frac{a_i^{q_i+1}}{q_i!}\left(\prod_{i=1}^{n} {(q_i+1)}^m\right)\left(\sum_{i=1}^{n} {(q_i+1)}^m+(k+1)^m\right) \]

再设

\[f(n,s)=\sum_{\sum_{i=1}^n q_i=s}\prod_{i=1}^n\frac{a_i^{q_i+1}}{q_i!}\left(\prod_{i=1}^{n} {(q_i+1)}^m\right) \]

容易发现

\[\begin{aligned} f(n+1,s+k)&\leftarrow f(n,s)\cdot \frac{(k+1)^m\cdot a_{n+1}^{k+1}}{k!}\\ g(n+1,s+k)&\leftarrow g(n,s)\cdot \frac{(k+1)^m\cdot a_{n+1}^{k+1}}{k!} +f(n,s)\cdot\frac{(k+1)^{2m}\cdot a_{n+1}^{k+1}}{k!} \end{aligned} \]

边界是 \(f(0,0)=1,g(0,0)=0\),这样我们就可以 \(\mathcal O(n^3)\) DP 了。

期望得分 \(20\) 分。

算法二

我们仔细观察,设 \(f(i,*),g(i,*)\) 的生成函数分别为 \(F_i(x),G_i(x)\),那么我们有

\[\begin{aligned} F_i(x)&=F_{i-1}(x)\cdot\left(\sum_{j=0}^{n-1}\frac{(j+1)^ma_i^{j+1}}{j!}x^j\right)\\ G_i(x)&=G_{i-1}(x)\cdot\left(\sum_{j=0}^{n-1}\frac{(j+1)^ma_i^{j+1}}{j!}x^j\right)+F_{i-1}(x)\cdot\left(\sum_{j=0}^{n-1}\frac{(j+1)^{2m}{j+1}}{j!}x^j\right) \end{aligned} \]

那么就可以 \(\mathcal O(n^2\log n)\) FFT 了,常数有点大不太能过得去,可能要优化一下常数或者用些啥技巧。

(或者可能这档分压根就不是这么做的 qwq)

期望得分 \(35\sim 40\) 分。假装它就是 \(40\) 吧。

算法三

所有 \(a_i\) 都一样的话,我们发现转移用到的生成函数也是一样的,因此不妨设

\[T_1=\sum_{j=0}^{n-1}\frac{(j+1)^ma_i^{j+1}}{j!}x^j\\ T_2=\sum_{j=0}^{n-1}\frac{(j+1)^{2m}{j+1}}{j!}x^j \]

多项式乘法是有交换律和结合律的,简单推导可以得到

\[F_i(x)=T_1^i\\ G_i(x)=i\cdot T_1^{i-1}\cdot T_2 \]

因为我们只需要 \([x^{n-2}]G_n(x)\),我们可以多项式快速幂一下。

时间复杂度就是 \(\mathcal O(n\log n)\) 或者 \(\mathcal O(n\log^2n)\)

结合算法二可以获得 \(60\) 分。

算法四

剩下的部分就是一些牛逼(套路)操作了。

仔细观察,转移用到的生成函数除了 \(a_i\),其它部分都很相似,我们不妨设

\[A(x)=\sum_{i=0}^{n-1}\frac{(i+1)^m}{i!}\\ B(x)=\sum_{i=0}^{n-1}\frac{(i+1)^{2m}}{i!} \]

那么有

\[\begin{aligned} F_i(x)&=F_{i-1}(x)\cdot a_iA(a_ix)\\ G_i(x)&=G_{i-1}(x)\cdot a_iA(a_ix)+F_{i-1}(x)\cdot a_iB(a_ix) \end{aligned} \]

简单推导可以得到

\[\begin{aligned} F_n(x)&=\prod_{i=1}^na_i\prod_{i=1}^nA(a_ix)\\ G_n(x)&=\prod_{i=1}^na_i\sum_{i=1}^n\prod_{j=1}^n\begin{cases}A(a_jx) & i \neq j\\B(a_jx) & i=j\end{cases}\\ \end{aligned} \]

\(G_n(x)\) 的表达式写得好一点是

\[G_n(x)=\prod_{i=1}^na_i\prod_{i=1}^nA(a_ix)\sum_{i=1}^n\left(\frac{B}{A}\right)(a_ix)\\ \]

显然对于某个多项式 \(F(x)\),求 \(\sum_{i=1}^nF(a_ix)\) 比求 \(\prod_{i=1}^nF(a_ix)\) 容易得多,我们考虑先求 ln 再求 exp

\[G_n(x)=\prod_{i=1}^na_i\left(e^{\sum_{i=1}^n(\ln A)(a_ix)}\sum_{i=1}^n\left(\frac{B}{A}\right)(a_ix)\right)\\ \]

整理一下,答案就是

\[\begin{aligned} \text{ans}&=(n-2)![x^{n-2}]G_n(x)\\&=(n-2)!\prod_{i=1}^na_i[x^{n-2}]\left(e^{\sum_{i=1}^n(\ln A)(a_ix)}\sum_{i=1}^n\left(\frac{B}{A}\right)(a_ix)\right) \end{aligned} \]

现在的问题转化为,对于一个多项式 \(F(x)\),求 \(\sum_{i=1}^n F(a_ix)\)

因为是求和,我们可以写成

\[\sum_{i=1}^n F(a_ix)=\sum_{i=0}^{n-1}x^i[x^i]F(x)\sum_{j=1}^na_j^i \]

那么现在的问题就是,对于每个 \(i\),求出 \(\sum_{j=1}^na_j^i\)

众所周知,\(\frac{1}{1-ax}=\sum_{i\geq0}a^ix^i\),因此上面的问题可以有如下转化

\[\sum_{j=1}^na_j^i=[x^i]\sum_{j=1}^n\frac{1}{1-a_jx} \]

这是个经典问题。因为问题规模不允许我们对于每个 \(1-a_jx\) 求逆后相加,所以我们考虑直接从分式入手。我们尝试分治这个和式,然后合并两边的分式的时候,就模拟分式通分后相加的过程

这样能保证分治的时候,该区间的多项式次数为该区间长度,从而保证复杂度。

至此我们就解决了这个问题,时间复杂度 \(\mathcal O(n\log^2n+n\log m)\)。所以 \(m\) 其实可以出到 \(10^{18}\)

注意特判 \(n=1\),否则你会在 UOJ 上获得 97 分的好分数,别问我是怎么知道的

#include <bits/stdc++.h>

template <class T>
inline void read(T &x)
{
	static char ch; 
	while (!isdigit(ch = getchar()));
	x = ch - '0'; 
	while (isdigit(ch = getchar()))
		x = x * 10 + ch - '0'; 
}

const int mod = 998244353; 

inline int qpow(int x, int y)
{
	int res = 1; 
	for (; y; y >>= 1, x = 1LL * x * x % mod)
		if (y & 1)
			res = 1LL * res * x % mod; 
	return res; 
}

inline void add(int &x, const int &y)
{
	x += y; 
	if (x >= mod)
		x -= mod; 
}

inline void dec(int &x, const int &y)
{
	x -= y;
	if (x < 0)
		x += mod; 
}

typedef std::vector<int> vi; 
typedef std::pair<vi, vi> pvi; 
#define mp(x, y) std::make_pair(x, y)

const int MaxN = 2e5 + 5; 
const int INF = 0x3f3f3f3f; 

int fac[MaxN], fac_inv[MaxN], pwm[MaxN], ind[MaxN]; 

inline void fac_init(int n)
{
	ind[1] = 1; 
	for (int i = 2; i <= n; ++i)
		ind[i] = 1LL * ind[mod % i] * (mod - mod / i) % mod; 

	fac[0] = 1; 
	for (int i = 1; i <= n; ++i)
		fac[i] = 1LL * fac[i - 1] * i % mod; 

	fac_inv[n] = qpow(fac[n], mod - 2); 
	for (int i = n - 1; i >= 0; --i)
		fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod; 
}

namespace polynomial
{
	int P, L; 
	int rev[MaxN]; 

	inline void DFT_init(int n)
	{
		P = 0, L = 1; 
		while (L < n)
			L <<= 1, ++P; 
		for (int i = 1; i < L; ++i)
			rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (P - 1)); 
	}
	
	inline void DFT(vi &a, int n, int opt)
	{
		for (int i = 0; i < n; ++i)
			if (i < rev[i])
				std::swap(a[i], a[rev[i]]);

		int g = opt == 1 ? 3 : (mod + 1) / 3; 
		for (int k = 1; k < n; k <<= 1)
		{
			int omega = qpow(g, (mod - 1) / (k << 1)); 
			for (int i = 0; i < n; i += k << 1)
			{
				int x = 1; 
				for (int j = 0; j < k; ++j)
				{
					int u = a[i + j]; 
					int v = 1LL * a[i + j + k] * x % mod; 
					add(a[i + j] = u, v); 
					dec(a[i + j + k] = u, v); 
					x = 1LL * x * omega % mod; 
				}
			}
		}
		if (opt == -1)
		{
			int inv = ind[n]; 
			for (int i = 0; i < n; ++i)
				a[i] = 1LL * a[i] * inv % mod; 
		}
	}

	inline vi plus(vi a, vi b)
	{
		int sze = std::max(a.size(), b.size()); 
		a.resize(sze), b.resize(sze); 

		for (int i = 0; i < sze; ++i)
			add(a[i], b[i]); 
		return a; 
	}
	inline vi mul(vi a, vi b, int lim = INF)
	{
		int sze = a.size() + b.size() - 1; 
		DFT_init(sze), a.resize(L, 0), b.resize(L, 0); 

		vi c(L); 
		DFT(a, L, 1), DFT(b, L, 1); 
		for (int i = 0; i < L; ++i)
			c[i] = 1LL * a[i] * b[i] % mod; 
		DFT(c, L, -1);

		return c.resize(std::min(sze, lim)), c; 
	}
	inline vi inverse(vi a)
	{
		int n = a.size(), m = 1; 
		vi b(1, qpow(a[0], mod - 2)), ta; 
		while (m < n)
		{
			m <<= 1; 
			DFT_init(m << 1); 

			b.resize(L, 0); 
			(ta = a).resize(m); 
			ta.resize(L, 0); 

			DFT(b, L, 1), DFT(ta, L, 1); 
			for (int i = 0; i < L; ++i)
				b[i] = 1LL * b[i] * (mod + 2 - 1LL * ta[i] * b[i] % mod) % mod; 
			DFT(b, L, -1); 

			b.resize(m, 0); 
		}
		return b.resize(n), b; 
	}
	inline vi derivative(vi a)
	{
		vi res(0); 
		for (int i = 1, lim = a.size(); i < lim; ++i)
			res.push_back(1LL * i * a[i] % mod); 
		return res; 
	}
	inline vi anti_derivative(vi a)
	{
		vi res(1, 0); 
		for (int i = 0, lim = a.size(); i < lim; ++i)
			res.push_back(1LL * a[i] * ind[i + 1] % mod); 
		return res; 
	}
	inline vi ln(vi a)
	{
		return anti_derivative(mul(derivative(a), inverse(a), a.size() - 1)); 
	}
	inline vi exp(vi a)
	{
		int n = a.size(), m = 1; 
		vi b(1, 1), ta; 
		while (m < n)
		{
			m <<= 1; 

			b.resize(m, 0); 
			vi ln_b = ln(b); 

			(ta = a).resize(m); 
			add(ta[0], 1); 
			for (int i = 0; i < m; ++i)
				dec(ta[i], ln_b[i]); 
			b = mul(b, ta, m); 
		}
		return b.resize(n), b; 
	}
}

vi sum; 
int n, m; 
int a[MaxN]; 

inline pvi solve(int l, int r)
{
	using namespace polynomial; 
	if (l == r)
	{
		vi t(1, 1); t.push_back(mod - a[l]); 
		return mp(vi(1, 1), t); 
	}
	int mid = (l + r) >> 1; 
	pvi lef = solve(l, mid), rit = solve(mid + 1, r); 
	return mp(plus(mul(lef.first, rit.second), mul(rit.first, lef.second)), mul(lef.second, rit.second)); 
}

inline vi get_sum(vi a)
{
	vi res(0); int n = a.size(); 
	for (int i = 0; i < n; ++i)
		res.push_back(1LL * a[i] * sum[i] % mod); 
	return res; 
}

int main()
{
	read(n), read(m), fac_init(MaxN - 1); 
	for (int i = 0; i <= (n << 1); ++i)
		pwm[i] = qpow(i, m); 

	int prod = 1; 
	for (int i = 1; i <= n; ++i)
	{
		read(a[i]);
		prod = 1LL * prod * a[i] % mod; 
	}

	if (n == 1)
		return puts(m ? "0" : "1"), 0; 

	using namespace polynomial; 

	pvi t = solve(1, n); 
	sum = mul(t.first, inverse(t.second), n - 1); 

	vi A(0), B(0); 
	for (int i = 0; i < n - 1; ++i)
	{
		A.push_back(1LL * pwm[i + 1] * fac_inv[i] % mod); 
		B.push_back(1LL * pwm[i + 1] * pwm[i + 1] % mod * fac_inv[i] % mod); 
	}
	B = get_sum(mul(B, inverse(A), n - 1)); 
	A = exp(get_sum(ln(A))); 

	int res = mul(A, B)[n - 2]; 
	std::cout << 1LL * fac[n - 2] * prod % mod * res % mod << '\n'; 

	return 0; 
}
posted @ 2020-02-06 12:13  changle_cyx  阅读(316)  评论(0编辑  收藏  举报