【BZOJ1016】【JSOI2008】最小生成树计数
Description
现在给出了一个简单无向加权图。你不满足于求出这个图的最小生成树,而希望知道这个图中有多少个不同的最小生成树。(如果两颗最小生成树中至少有一条边不同,则这两个最小生成树就是不同的)。由于不同的最小生成树可能很多,所以你只需要输出方案数对31011的模就可以了。
Input
第一行包含两个数,n和m,其中1<=n<=100; 1<=m<=1000; 表示该无向图的节点数和边数。每个节点用1~n的整数编号。接下来的m行,每行包含两个整数:a, b, c,表示节点a, b之间的边的权值为c,其中1<=c<=1,000,000,000。数据保证不会出现自回边和重边。注意:具有相同权值的边不会超过10条。
Output
输出不同的最小生成树有多少个。你只需要输出数量对31011的模就可以了。
Sample Input
4 6
1 2 1
1 3 1
1 4 1
2 3 2
2 4 1
3 4 1
Sample Output
8
Solution
对于不同的MST方案,相同权值的边数总是一定。因此我们可以排序后先跑一次MST并离散权值,统计出每种权值被用了多少次,然后对于每种权值暴力枚举各条边是否使用,检查合法性以及使用边数是否等于原本所使用的,乘法原理一下即可,由于有撤销操作不能够压缩路径,因此建议采用启发式合并。
时间复杂度\(O(E \log_2 E + \Sigma 2^{cnt_{v}} * \log_{2} V)\)其中\(cnt_{v}\)表示权值v在MST中的使用次数。
Code
#include <stdio.h>
#include <algorithm>
#define MN 105
#define MM 1005
#define R register
#define mod 31011
inline int read(){
R int x; R bool f; R char c;
for (f=0; (c=getchar())<'0'||c>'9'; f=c=='-');
for (x=c-'0'; (c=getchar())>='0'&&c<='9'; x=(x<<3)+(x<<1)+c-'0');
return f?-x:x;
}
int V,E,cnt,val[MM],x[MM],y[MM],rk[MM],l[MM],r[MM],v[MM],fa[MN],sz[MN],sum,ans=1,k;
inline bool cmp(int x,int y){return val[x]<val[y];}
inline int find(int x){return fa[x]==x?x:find(fa[x]);}
inline void swap(int &x,int &y){x^=y,y^=x,x^=y;}
inline void ins(int x,int y){
if (sz[x]<sz[y]) swap(x,y);
fa[y]=x;sz[x]+=sz[y];
}
inline void del(int x,int y){
if (fa[x]==y) swap(x,y);
fa[y]=y; sz[x]-=sz[y];
}
inline void dfs(int t,int no,int k){
if (no>r[t]){
sum+=(k==v[t]);
return;
}R int p=find(x[rk[no]]),q=find(y[rk[no]]);
if (p!=q){
ins(p,q);dfs(t,no+1,k+1);del(p,q);
}dfs(t,no+1,k);
}
int main(){
V=read(),E=read();for (R int i=1; i<=E; ++i) x[i]=read(),y[i]=read(),val[i]=read(),rk[i]=i;
std::sort(rk+1,rk+E+1,cmp);for (R int i=1; i<=V; ++i) fa[i]=i;
for (R int i=1; i<=E; ++i){
if (val[rk[i]]!=val[rk[i-1]]) {r[cnt]=i-1; if (k==V-1) break;l[++cnt]=i;}
R int p=find(x[rk[i]]),q=find(y[rk[i]]);
if (p!=q){ins(p,q);++v[cnt],++k;}
}if (!r[cnt]) r[cnt]=E;if (k!=V-1){puts("0");return 0;}for (R int i=1; i<=V; ++i) fa[i]=i;
for (R int i=1; i<=cnt; ++i){
sum=0;dfs(i,l[i],0);ans=ans*sum%mod;
for (R int j=l[i]; j<=r[i]; ++j){
R int p=find(x[rk[j]]),q=find(y[rk[j]]);
if (p!=q) ins(p,q);
}
}printf("%d\n",ans);
return 0;
}