【洛谷2624_BZOJ1005】[HNOI2008] 明明的烦恼(Prufer序列_高精度_组合数学)

题目:

洛谷2624

分析:

本文中所有的 “树” 都是带标号的。

介绍一种把树变成一个序列的工具:Prufer 序列。

对于一棵 \(n\) 个结点的树,每次选出一个叶子(度数为 \(1\) 的结点),将唯一的那个与它相连的点标号加入 Prufer 序列末尾,然后删去这个叶子及其所连的边,直到最后剩下两个点和一条边。由于每次删且仅删一个点和一条边,所以 Prufer 序列长度为 \(n-2\) 。点 \(a\) 在序列中每次出现都意味着一条与它相连的边被删去了,一直删到 \(a\) 度数为 \(1\) (叶子)为止。所以任意一点在 Prufer 序列中的出现次数就是它的度数减 \(1\)

再扯远点。把一个 Prufer 序列还原成树的方式是:定义集合 \(V=\{i|1\leq i \leq n\}\) 每次选出 \(V\) 中编号最小的没有在当前序列出现过的点 \(u\),与 Prufer 序列第一个元素 \(v\) 相连,然后从 \(V\) 中删除 \(u\),并删除 Prufer 序列的第一个元素。最后 \(V\) 中剩余两个元素,将它们相连。

按照上述方法,任意一个由整数 \(1\)\(n\) 组成的序列都可以生成一棵 \(n\) 个结点的树,且任意一棵 \(n\) 个结点的树也可以生成一个这样的序列,所以序列和树一一对应。由此可得推论:\(n\) 个结点的树有 \(n^{n-2}\) 种。

从 Prufer 序列的角度考虑可以解决很多树计数的问题。回到本题。设 \(sum=\sum_{d_i\neq -1} di\)\(num=\sum_{d_i\neq -1} 1\) (即:有 \(num\) 个结点被指定度数,它们的度数之和为 \(sum\) ),则问题被转化为:求一个长为 \(n-2\) 个序列,使其中对于任意 \(i(d_i\neq -1)\) 满足 \(i\) 的出现次数是 \(d_i-1\) 。首先从序列种选出 \(sum-num\) 个位置来放置这些有限制的数,根据可重集合排列公式 \(\frac{(\sum a_i)!}{\prod (a_i!)}\) ,方案数是:

\[C_{n-2}^{sum-num}\frac{(sum-num)!}{\prod_{d_i\neq -1} ((d_i-1)!)} \]

然后剩下 \(n-2-(sum-num)\) 个位置可以从 \(n-num\) 个数中任取,所以最终答案是:

\[C_{n-2}^{sum-num}\frac{(sum-num)!}{\prod_{d_i\neq -1} ((d_i-1)!)}\cdot (n-num)^{(n-2-(sum-num))} \]

很不幸发现这题没模数,也看不懂什么分解质因数的做法,乖乖写高精度吧……(其实写这篇博客的主要目的是记录高精度板子qwq

代码:

注意一定要把\(\prod_{d_i\neq -1}((d_i-1)!)\)先全部乘起来再除!!!

\(n\) 遍除法会 TLE !!!

(我就不幸因为这个 T 了一发……

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cctype>
using namespace std;

namespace zyt
{
	typedef long long ll;
	const int N = 1e3 + 10;
	class long_long_long
	{
	private:
		typedef long_long_long Tril;
		const static int DIGIT = 4, BASE = 10000, LEN = (3e3 + 100) / DIGIT;
		int data[LEN], len;
	public:
		long_long_long(const int x = 0)
		{
			memset(data, 0, sizeof(data));
			len = 0;
			data[0] = x;
			while (data[len])
			{
				data[len + 1] = data[len] / BASE;
				data[len++] %= BASE;
			}
		}
		bool operator == (const Tril &b) const
		{
			if (len != b.len)
				return false;
			for (int i = 0; i < len; i++)
				if (data[i] != b.data[i])
					return false;
			return true;
		}
		bool operator < (const Tril &b) const
		{
			if (len != b.len)
				return len < b.len;
			for (int i = len - 1; i >= 0; i--)
				if (data[i] != b.data[i])
					return data[i] < b.data[i];
			return false;
		}
		Tril operator + (const Tril &b) const
		{
			Tril ans;
			ans.len = max(len, b.len);
			for (int i = 0; i < ans.len; i++)
			{
				ans.data[i] += data[i] + b.data[i];
				ans.data[i + 1] += ans.data[i] / BASE;
				ans.data[i] %= BASE;
			}
			if (ans.data[ans.len])
				++ans.len;
			return ans;
		}
		Tril operator - (const Tril &b) const
		{
			Tril ans;
			ans.len = len;
			for (int i = 0; i < len; i++)
			{
				ans.data[i] += data[i] - b.data[i];
				if (ans.data[i] < 0)
				{
					--ans.data[i + 1];
					ans.data[i] += BASE;
				}
			}
			while (ans.len && !ans.data[ans.len - 1])
				--ans.len;
			return ans;
		}
		Tril operator * (const Tril &b) const
		{
			static ll tmp[LEN];
			Tril ans;
			if (*this == 0 || b == 0)
				return ans;
			ans.len = len + b.len - 1;
			memset(tmp, 0, sizeof(ll[len + b.len + 2]));
			for (int i = 0; i < len; i++)
				for (int j = 0; j < b.len; j++)
					tmp[i + j] += data[i] * b.data[j];
			for (int i = 0; i < ans.len; i++)
			{
				tmp[i + 1] += tmp[i] / BASE;
				ans.data[i] = tmp[i] % BASE;
			}
			while (tmp[ans.len])
			{
				tmp[ans.len + 1] += tmp[ans.len] / BASE;
				ans.data[ans.len] = tmp[ans.len] % BASE;
				++ans.len;
			}
			return ans;
		}
		Tril operator / (const Tril &b) const
		{
			Tril ans, rest;
			for (int i = len - 1; i >= 0; i--)
			{
				rest = rest * BASE + data[i];
				if (rest < b)
					continue;
				int l = 1, r = BASE - 1, anss;
				while (l <= r)
				{
					int mid = (l + r) >> 1;
					if (rest < b * mid)
						r = mid - 1;
					else
						l = mid + 1, anss = mid;
				}
				ans.data[i] = anss;
				rest = rest - b * anss;
				if (!ans.len)
					ans.len = i + 1;
			}
			return ans;
		}
		void print()
		{
			if (!len)
				printf("0");
			else
			{
				printf("%d", data[len - 1]);
				for (int i = len - 2; i >= 0; i--)
					printf("%0*d", DIGIT, data[i]);
			}
		}
	};
	typedef long_long_long Tril;
	Tril ans, fac[N];
	int n;
	Tril C(const int n, const int m)
	{
		return fac[n] / fac[m] / fac[n - m];
	}
	Tril power(Tril a, int b)
	{
		Tril ans = 1;
		while (b)
		{
			if (b & 1)
				ans = ans * a;
			a = a * a;
			b >>= 1;
		}
		return ans;
	}
	int d[N];
	int work()
	{
		int sum = 0, num = 0;
		scanf("%d", &n);
		fac[0] = 1;
		for (int i = 1; i <= n; i++)
		{
			fac[i] = fac[i - 1] * i;
			scanf("%d", &d[i]);
			if (d[i] == -1)
				continue;
			sum += d[i], ++num;
		}
		ans = fac[sum - num];
		for (int i = 1; i <= n; i++)
			if (d[i] > 0)
				ans = ans / fac[d[i] - 1];
		ans = ans * C(n - 2, sum - num);
		ans = ans * power(n - num, n - 2 - (sum - num));
		ans.print();
		return 0;
	}
}
int main()
{
	return zyt::work();
}
posted @ 2019-02-24 22:41  Inspector_Javert  阅读(163)  评论(0编辑  收藏  举报