「51Nod 1601」完全图的最小生成树计数 「Trie」

题意

给定\(n\)个带权点,第\(i\)个点的权值为\(w_i\),任意两点间都有边,边权为两端点权的异或值,求最小生成树边权和,以及方案数\(\bmod 10^9 + 7\)

\(n \leq 10^5,W = max(w_i) \leq 2^{30}\)

题解

考虑按位贪心,我们从高到低考虑二进制第k位。每次把当前点集\(S\)分成第\(k\)位为\(0\)和第\(k\)位为\(1\)的两个集合,记为\(S_0, S_1\)

我们递归下去把这两个集合连成生成树,然后再找一条最小的跨集合的边把这两个集合连通。

考虑这么做为啥对:假设有两条跨集合的边,我删去一条,树变成两个部分。然后任意找到一条集合内部边使集合\(S\)连通(既然有跨集合的边存在,我们一定能找到这样的一条边),这样显然更优。

然后考虑问题:找到\(x\in S_0,y\in S_1,x\text{ xor } y\)最小。

这个用类似线段树合并的方法:每次两个结点同时往下走,尽量往一边走。如果能同时往\(0/1\)走,都走一遍,复杂度是对的,每次合并复杂度是子树大小。考虑trie树上一个点只有\(O(\log W)\)个祖先,一共只有\(O(n \log W)\)个结点,所以复杂度\(O(n \log ^2 W)\)

我们再来考虑方案。叶子结点时假设大小为\(n\),也就是说\(n\)个点都是这个权值,生成树的方案数\(n^{n-2}\)(由prufer序列得)。非叶子结点时,方案是分成的两个集合的方案乘最后连边方案。连边会对应trie树上多对叶子\((u, v)\)(这些对结点异或起来都是最小的),若叶子\(u\)上放的数个数用\(cnt[u]\)表示,连边方案就是\(\sum_{(u,v)} cnt[u]*cnt[v]\)

P.S.:快速幂写错了调了好久,差评

#include <algorithm>
#include <cstdio>
using namespace std;
typedef long long ll;
char gc() {
	static char buf[1 << 20], * S, * T;
	if(S == T) {
		T = (S = buf) + fread(buf, 1, 1 << 20, stdin);
		if(S == T) return EOF;
	}
	return *S ++;
}
template<typename T> void read(T &x) {
	x = 0; char c = gc(); bool bo = 0;
	for(; c > '9' || c < '0'; c = gc()) bo |= c == '-';
	for(; c <= '9' && c >= '0'; c = gc()) x = x * 10 + (c & 15);
	if(bo) x = -x;
}
const int N = 1e5 + 10;
const int mo = 1e9 + 7;
int n, id = 1, ch[N * 30][2], cnt[N * 30], w[N * 30];
void insert(int x) {
	int u = 1;
	for(int i = 29; ~ i; i --) {
		int y = x >> i & 1;
		if(!ch[u][y]) {
			ch[u][y] = ++ id;
			w[id] = y << i;
		}
		u = ch[u][y];
	}
	cnt[u] ++;
}
ll ans;
int ans2, tot2, tot = 1;
int qpow(int a, int b) {
	int ans = 1;
	for(; b >= 1; b >>= 1, a = (ll) a * a % mo)
		if(b & 1) ans = (ll) ans * a % mo;
	return ans;
}
void merge(int u, int v, int now) {
	now ^= w[u] ^ w[v];
	if(cnt[u] && cnt[v]) {
		if(now < ans2) { ans2 = now; tot2 = 0; }
		if(now == ans2) tot2 = (tot2 + (ll) cnt[u] * cnt[v]) % mo;
		return ;
	}
	bool tag = 0;
	if(ch[u][0] && ch[v][0]) merge(ch[u][0], ch[v][0], now), tag = 1;
	if(ch[u][1] && ch[v][1]) merge(ch[u][1], ch[v][1], now), tag = 1;
	if(tag) return ;
	if(ch[u][0] && ch[v][1]) merge(ch[u][0], ch[v][1], now);
	if(ch[u][1] && ch[v][0]) merge(ch[u][1], ch[v][0], now);
}
bool solve(int u) {
	if(!u) return 0;
	if(cnt[u]) {
		if(cnt[u] > 2) tot = (ll) tot * qpow(cnt[u], cnt[u] - 2) % mo;
		return 1;
	}
	bool s = solve(ch[u][1]) & solve(ch[u][0]);
	if(s) {
		ans2 = 2e9 + 10; tot2 = 1;
		merge(ch[u][0], ch[u][1], 0);
		ans += ans2; tot = (ll) tot * tot2 % mo;
	}
	return 1;
}
int main() {
	read(n);
	for(int i = 1; i <= n; i ++) {
		int x; read(x); insert(x);
	}
	solve(1);
	printf("%lld\n%d\n", ans, tot);
	return 0;
}
posted @ 2019-10-11 18:06  hfhongzy  阅读(217)  评论(0编辑  收藏  举报