【洛谷2624_BZOJ1005】[HNOI2008] 明明的烦恼(Prufer序列_高精度_组合数学)
题目:
分析:
本文中所有的 “树” 都是带标号的。
介绍一种把树变成一个序列的工具:Prufer 序列。
对于一棵 \(n\) 个结点的树,每次选出一个叶子(度数为 \(1\) 的结点),将唯一的那个与它相连的点标号加入 Prufer 序列末尾,然后删去这个叶子及其所连的边,直到最后剩下两个点和一条边。由于每次删且仅删一个点和一条边,所以 Prufer 序列长度为 \(n-2\) 。点 \(a\) 在序列中每次出现都意味着一条与它相连的边被删去了,一直删到 \(a\) 度数为 \(1\) (叶子)为止。所以任意一点在 Prufer 序列中的出现次数就是它的度数减 \(1\) 。
再扯远点。把一个 Prufer 序列还原成树的方式是:定义集合 \(V=\{i|1\leq i \leq n\}\) 每次选出 \(V\) 中编号最小的没有在当前序列出现过的点 \(u\),与 Prufer 序列第一个元素 \(v\) 相连,然后从 \(V\) 中删除 \(u\),并删除 Prufer 序列的第一个元素。最后 \(V\) 中剩余两个元素,将它们相连。
按照上述方法,任意一个由整数 \(1\) 到 \(n\) 组成的序列都可以生成一棵 \(n\) 个结点的树,且任意一棵 \(n\) 个结点的树也可以生成一个这样的序列,所以序列和树一一对应。由此可得推论:\(n\) 个结点的树有 \(n^{n-2}\) 种。
从 Prufer 序列的角度考虑可以解决很多树计数的问题。回到本题。设 \(sum=\sum_{d_i\neq -1} di\) 和 \(num=\sum_{d_i\neq -1} 1\) (即:有 \(num\) 个结点被指定度数,它们的度数之和为 \(sum\) ),则问题被转化为:求一个长为 \(n-2\) 个序列,使其中对于任意 \(i(d_i\neq -1)\) 满足 \(i\) 的出现次数是 \(d_i-1\) 。首先从序列种选出 \(sum-num\) 个位置来放置这些有限制的数,根据可重集合排列公式 \(\frac{(\sum a_i)!}{\prod (a_i!)}\) ,方案数是:
然后剩下 \(n-2-(sum-num)\) 个位置可以从 \(n-num\) 个数中任取,所以最终答案是:
很不幸发现这题没模数,也看不懂什么分解质因数的做法,乖乖写高精度吧……(其实写这篇博客的主要目的是记录高精度板子qwq
代码:
注意一定要把\(\prod_{d_i\neq -1}((d_i-1)!)\)先全部乘起来再除!!!
算 \(n\) 遍除法会 TLE !!!
(我就不幸因为这个 T 了一发……
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cctype>
using namespace std;
namespace zyt
{
typedef long long ll;
const int N = 1e3 + 10;
class long_long_long
{
private:
typedef long_long_long Tril;
const static int DIGIT = 4, BASE = 10000, LEN = (3e3 + 100) / DIGIT;
int data[LEN], len;
public:
long_long_long(const int x = 0)
{
memset(data, 0, sizeof(data));
len = 0;
data[0] = x;
while (data[len])
{
data[len + 1] = data[len] / BASE;
data[len++] %= BASE;
}
}
bool operator == (const Tril &b) const
{
if (len != b.len)
return false;
for (int i = 0; i < len; i++)
if (data[i] != b.data[i])
return false;
return true;
}
bool operator < (const Tril &b) const
{
if (len != b.len)
return len < b.len;
for (int i = len - 1; i >= 0; i--)
if (data[i] != b.data[i])
return data[i] < b.data[i];
return false;
}
Tril operator + (const Tril &b) const
{
Tril ans;
ans.len = max(len, b.len);
for (int i = 0; i < ans.len; i++)
{
ans.data[i] += data[i] + b.data[i];
ans.data[i + 1] += ans.data[i] / BASE;
ans.data[i] %= BASE;
}
if (ans.data[ans.len])
++ans.len;
return ans;
}
Tril operator - (const Tril &b) const
{
Tril ans;
ans.len = len;
for (int i = 0; i < len; i++)
{
ans.data[i] += data[i] - b.data[i];
if (ans.data[i] < 0)
{
--ans.data[i + 1];
ans.data[i] += BASE;
}
}
while (ans.len && !ans.data[ans.len - 1])
--ans.len;
return ans;
}
Tril operator * (const Tril &b) const
{
static ll tmp[LEN];
Tril ans;
if (*this == 0 || b == 0)
return ans;
ans.len = len + b.len - 1;
memset(tmp, 0, sizeof(ll[len + b.len + 2]));
for (int i = 0; i < len; i++)
for (int j = 0; j < b.len; j++)
tmp[i + j] += data[i] * b.data[j];
for (int i = 0; i < ans.len; i++)
{
tmp[i + 1] += tmp[i] / BASE;
ans.data[i] = tmp[i] % BASE;
}
while (tmp[ans.len])
{
tmp[ans.len + 1] += tmp[ans.len] / BASE;
ans.data[ans.len] = tmp[ans.len] % BASE;
++ans.len;
}
return ans;
}
Tril operator / (const Tril &b) const
{
Tril ans, rest;
for (int i = len - 1; i >= 0; i--)
{
rest = rest * BASE + data[i];
if (rest < b)
continue;
int l = 1, r = BASE - 1, anss;
while (l <= r)
{
int mid = (l + r) >> 1;
if (rest < b * mid)
r = mid - 1;
else
l = mid + 1, anss = mid;
}
ans.data[i] = anss;
rest = rest - b * anss;
if (!ans.len)
ans.len = i + 1;
}
return ans;
}
void print()
{
if (!len)
printf("0");
else
{
printf("%d", data[len - 1]);
for (int i = len - 2; i >= 0; i--)
printf("%0*d", DIGIT, data[i]);
}
}
};
typedef long_long_long Tril;
Tril ans, fac[N];
int n;
Tril C(const int n, const int m)
{
return fac[n] / fac[m] / fac[n - m];
}
Tril power(Tril a, int b)
{
Tril ans = 1;
while (b)
{
if (b & 1)
ans = ans * a;
a = a * a;
b >>= 1;
}
return ans;
}
int d[N];
int work()
{
int sum = 0, num = 0;
scanf("%d", &n);
fac[0] = 1;
for (int i = 1; i <= n; i++)
{
fac[i] = fac[i - 1] * i;
scanf("%d", &d[i]);
if (d[i] == -1)
continue;
sum += d[i], ++num;
}
ans = fac[sum - num];
for (int i = 1; i <= n; i++)
if (d[i] > 0)
ans = ans / fac[d[i] - 1];
ans = ans * C(n - 2, sum - num);
ans = ans * power(n - num, n - 2 - (sum - num));
ans.print();
return 0;
}
}
int main()
{
return zyt::work();
}