【LOJ #2320】「清华集训 2017」生成树计数
Description
题目链接:
在一个 \(s\) 个点的图中,存在 \(s-n\) 条边,使图中形成了 \(n\) 个连通块,第 \(i\) 个连通块中有 \(a_i\) 个点。
现在我们需要再连接 \(n-1\) 条边,使该图变成一棵树。对一种连边方案,设原图中第 \(i\) 个连通块连出了 \(d_i\) 条边,那么这棵树 \(T\) 的价值为:
你的任务是求出所有可能的生成树的价值之和,对 \(998244353\) 取模。
\(n \leq 3\times 10^4,m \leq 30\)
时空限制:\(\texttt{5s/1GB}\)
Solution
算法一
由于我比较菜,所以想了半天才会这个暴力。
将每个连通块看成一个点,首先我们知道 Prufer 序列中每个点的出现次数就是度数减一,因此我们不妨考虑枚举度数序列计算。
考虑在两个大小分别为 \(a\) 和 \(b\) 的连通块之间连边有 \(a\cdot b\) 种选择,因此我们把所有边的贡献相乘,所以每种连通块的生成树对应的原树的方案数为 \(\prod_{i=1}^na_i^{d_i}\)。
设 \(q_i\) 表示 Prufer 序列中 \(i\) 的出现次数,即 \(q_i=d_i-1\)。如果确定了一个 \(\sum q_i=n-2\),那么我们有
这个式子只需要 \(q_i\) 的信息即可计算,我们仔细观察可以发现这个式子是可以 DP 的。
首先我们将奇怪的项先提出来,得到
考虑当前考虑到前 \(n\) 个点有 \(\sum_{i=1}^nq_i=s\),需要考虑的式子是下面这样的,不妨设它为 \(g(n,s)\)
那么考虑新加入一个 \(q_{n+1}=k\),这个式子就变为
再设
容易发现
边界是 \(f(0,0)=1,g(0,0)=0\),这样我们就可以 \(\mathcal O(n^3)\) DP 了。
期望得分 \(20\) 分。
算法二
我们仔细观察,设 \(f(i,*),g(i,*)\) 的生成函数分别为 \(F_i(x),G_i(x)\),那么我们有
那么就可以 \(\mathcal O(n^2\log n)\) FFT 了,常数有点大不太能过得去,可能要优化一下常数或者用些啥技巧。
(或者可能这档分压根就不是这么做的 qwq)
期望得分 \(35\sim 40\) 分。假装它就是 \(40\) 吧。
算法三
所有 \(a_i\) 都一样的话,我们发现转移用到的生成函数也是一样的,因此不妨设
多项式乘法是有交换律和结合律的,简单推导可以得到
因为我们只需要 \([x^{n-2}]G_n(x)\),我们可以多项式快速幂一下。
时间复杂度就是 \(\mathcal O(n\log n)\) 或者 \(\mathcal O(n\log^2n)\)。
结合算法二可以获得 \(60\) 分。
算法四
剩下的部分就是一些牛逼(套路)操作了。
仔细观察,转移用到的生成函数除了 \(a_i\),其它部分都很相似,我们不妨设
那么有
简单推导可以得到
把 \(G_n(x)\) 的表达式写得好一点是
显然对于某个多项式 \(F(x)\),求 \(\sum_{i=1}^nF(a_ix)\) 比求 \(\prod_{i=1}^nF(a_ix)\) 容易得多,我们考虑先求 ln 再求 exp
整理一下,答案就是
现在的问题转化为,对于一个多项式 \(F(x)\),求 \(\sum_{i=1}^n F(a_ix)\)。
因为是求和,我们可以写成
那么现在的问题就是,对于每个 \(i\),求出 \(\sum_{j=1}^na_j^i\)。
众所周知,\(\frac{1}{1-ax}=\sum_{i\geq0}a^ix^i\),因此上面的问题可以有如下转化
这是个经典问题。因为问题规模不允许我们对于每个 \(1-a_jx\) 求逆后相加,所以我们考虑直接从分式入手。我们尝试分治这个和式,然后合并两边的分式的时候,就模拟分式通分后相加的过程。
这样能保证分治的时候,该区间的多项式次数为该区间长度,从而保证复杂度。
至此我们就解决了这个问题,时间复杂度 \(\mathcal O(n\log^2n+n\log m)\)。所以 \(m\) 其实可以出到 \(10^{18}\)。
注意特判 \(n=1\),否则你会在 UOJ 上获得 97 分的好分数,别问我是怎么知道的。
#include <bits/stdc++.h>
template <class T>
inline void read(T &x)
{
static char ch;
while (!isdigit(ch = getchar()));
x = ch - '0';
while (isdigit(ch = getchar()))
x = x * 10 + ch - '0';
}
const int mod = 998244353;
inline int qpow(int x, int y)
{
int res = 1;
for (; y; y >>= 1, x = 1LL * x * x % mod)
if (y & 1)
res = 1LL * res * x % mod;
return res;
}
inline void add(int &x, const int &y)
{
x += y;
if (x >= mod)
x -= mod;
}
inline void dec(int &x, const int &y)
{
x -= y;
if (x < 0)
x += mod;
}
typedef std::vector<int> vi;
typedef std::pair<vi, vi> pvi;
#define mp(x, y) std::make_pair(x, y)
const int MaxN = 2e5 + 5;
const int INF = 0x3f3f3f3f;
int fac[MaxN], fac_inv[MaxN], pwm[MaxN], ind[MaxN];
inline void fac_init(int n)
{
ind[1] = 1;
for (int i = 2; i <= n; ++i)
ind[i] = 1LL * ind[mod % i] * (mod - mod / i) % mod;
fac[0] = 1;
for (int i = 1; i <= n; ++i)
fac[i] = 1LL * fac[i - 1] * i % mod;
fac_inv[n] = qpow(fac[n], mod - 2);
for (int i = n - 1; i >= 0; --i)
fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;
}
namespace polynomial
{
int P, L;
int rev[MaxN];
inline void DFT_init(int n)
{
P = 0, L = 1;
while (L < n)
L <<= 1, ++P;
for (int i = 1; i < L; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (P - 1));
}
inline void DFT(vi &a, int n, int opt)
{
for (int i = 0; i < n; ++i)
if (i < rev[i])
std::swap(a[i], a[rev[i]]);
int g = opt == 1 ? 3 : (mod + 1) / 3;
for (int k = 1; k < n; k <<= 1)
{
int omega = qpow(g, (mod - 1) / (k << 1));
for (int i = 0; i < n; i += k << 1)
{
int x = 1;
for (int j = 0; j < k; ++j)
{
int u = a[i + j];
int v = 1LL * a[i + j + k] * x % mod;
add(a[i + j] = u, v);
dec(a[i + j + k] = u, v);
x = 1LL * x * omega % mod;
}
}
}
if (opt == -1)
{
int inv = ind[n];
for (int i = 0; i < n; ++i)
a[i] = 1LL * a[i] * inv % mod;
}
}
inline vi plus(vi a, vi b)
{
int sze = std::max(a.size(), b.size());
a.resize(sze), b.resize(sze);
for (int i = 0; i < sze; ++i)
add(a[i], b[i]);
return a;
}
inline vi mul(vi a, vi b, int lim = INF)
{
int sze = a.size() + b.size() - 1;
DFT_init(sze), a.resize(L, 0), b.resize(L, 0);
vi c(L);
DFT(a, L, 1), DFT(b, L, 1);
for (int i = 0; i < L; ++i)
c[i] = 1LL * a[i] * b[i] % mod;
DFT(c, L, -1);
return c.resize(std::min(sze, lim)), c;
}
inline vi inverse(vi a)
{
int n = a.size(), m = 1;
vi b(1, qpow(a[0], mod - 2)), ta;
while (m < n)
{
m <<= 1;
DFT_init(m << 1);
b.resize(L, 0);
(ta = a).resize(m);
ta.resize(L, 0);
DFT(b, L, 1), DFT(ta, L, 1);
for (int i = 0; i < L; ++i)
b[i] = 1LL * b[i] * (mod + 2 - 1LL * ta[i] * b[i] % mod) % mod;
DFT(b, L, -1);
b.resize(m, 0);
}
return b.resize(n), b;
}
inline vi derivative(vi a)
{
vi res(0);
for (int i = 1, lim = a.size(); i < lim; ++i)
res.push_back(1LL * i * a[i] % mod);
return res;
}
inline vi anti_derivative(vi a)
{
vi res(1, 0);
for (int i = 0, lim = a.size(); i < lim; ++i)
res.push_back(1LL * a[i] * ind[i + 1] % mod);
return res;
}
inline vi ln(vi a)
{
return anti_derivative(mul(derivative(a), inverse(a), a.size() - 1));
}
inline vi exp(vi a)
{
int n = a.size(), m = 1;
vi b(1, 1), ta;
while (m < n)
{
m <<= 1;
b.resize(m, 0);
vi ln_b = ln(b);
(ta = a).resize(m);
add(ta[0], 1);
for (int i = 0; i < m; ++i)
dec(ta[i], ln_b[i]);
b = mul(b, ta, m);
}
return b.resize(n), b;
}
}
vi sum;
int n, m;
int a[MaxN];
inline pvi solve(int l, int r)
{
using namespace polynomial;
if (l == r)
{
vi t(1, 1); t.push_back(mod - a[l]);
return mp(vi(1, 1), t);
}
int mid = (l + r) >> 1;
pvi lef = solve(l, mid), rit = solve(mid + 1, r);
return mp(plus(mul(lef.first, rit.second), mul(rit.first, lef.second)), mul(lef.second, rit.second));
}
inline vi get_sum(vi a)
{
vi res(0); int n = a.size();
for (int i = 0; i < n; ++i)
res.push_back(1LL * a[i] * sum[i] % mod);
return res;
}
int main()
{
read(n), read(m), fac_init(MaxN - 1);
for (int i = 0; i <= (n << 1); ++i)
pwm[i] = qpow(i, m);
int prod = 1;
for (int i = 1; i <= n; ++i)
{
read(a[i]);
prod = 1LL * prod * a[i] % mod;
}
if (n == 1)
return puts(m ? "0" : "1"), 0;
using namespace polynomial;
pvi t = solve(1, n);
sum = mul(t.first, inverse(t.second), n - 1);
vi A(0), B(0);
for (int i = 0; i < n - 1; ++i)
{
A.push_back(1LL * pwm[i + 1] * fac_inv[i] % mod);
B.push_back(1LL * pwm[i + 1] * pwm[i + 1] % mod * fac_inv[i] % mod);
}
B = get_sum(mul(B, inverse(A), n - 1));
A = exp(get_sum(ln(A)));
int res = mul(A, B)[n - 2];
std::cout << 1LL * fac[n - 2] * prod % mod * res % mod << '\n';
return 0;
}