BZOJ1016: [JSOI2008]最小生成树计数
首先一定要清楚以下定理,一个无向图所有的最小生成树中某种权值的边的数目均相同。
引用一篇大牛的证明:
我们证明以下定理:一个无向图所有的最小生成树中某种权值的边的数目均相同。
开始时,每个点单独构成一个集合。
首先只考虑权值最小的边,将它们全部添加进图中,并去掉环,由于是全部尝试添加,那么只要是用这种权值的边能够连通的点,最终就一定能在一个集合中。
那么不管添加的是哪些边,最终形成的集合数都是一定的,且集合的划分情况一定相同。那么真正添加的边数也是相同的。因为每添加一条边集合的数目便减少1.
那么权值第二小的边呢?我们将之间得到的集合每个集合都缩为一个点,那么权值第二小的边就变成了当前权值最小的边,也有上述的结论。
因此每个阶段,添加的边数都是相同的。我们以权值划分阶段,那么也就意味着某种权值的边的数目是完全相同的。
知道这个定理,就可以先用kruskal计算出一棵最小生成树,然后记录最小生成树上每个权值对应的不同边有几条。
然后根据证明过程的每个阶段,穷举所有的方案,看是否能在制定次数内全部加入森林,记录每个阶段解决方案数,最后用乘法计数原理,乘起来就是答案。
这份代码就看着黄学长的代码风格写的
/************************************************************** Problem: 1016 User: 96655 Language: C++ Result: Accepted Time:8 ms Memory:1516 kb ****************************************************************/ #include<cstdio> #include<algorithm> #include<iostream> #include<cstring> #include<vector> #include<stack> #include<cmath> #include<queue> #include<map> using namespace std; int n,m,cnt,tot,ans,sum; const int maxn=10005; const int mod=31011; int fa[maxn/10]; struct Edge { int u,v,w; bool operator<(const Edge &h)const { return w<h.w; } } e[maxn]; struct data { int l,r,c; } a[maxn]; inline int read() { int x=0; char ch=getchar(); while(ch<'0'||ch>'9') { ch=getchar(); } while(ch>='0'&&ch<='9') { x=x*10+ch-'0'; ch=getchar(); } return x; } int find(int x) { if(x==fa[x])return x; return find(fa[x]); } void dfs(int pos,int now,int k) { if(now==a[pos].r+1) { if(k==a[pos].c)++sum; return; } int fx=find(e[now].u); int fy=find(e[now].v); if(fx!=fy) { fa[fx]=fy; dfs(pos,now+1,k+1); fa[fx]=fx; fa[fy]=fy; } dfs(pos,now+1,k); } int main() { n=read(); m=read(); for(int i=1; i<=n; i++) fa[i]=i; for(int i=1; i<=m; i++) e[i].u=read(),e[i].v=read(),e[i].w=read(); sort(e+1,e+1+m); tot=cnt=0; ans=1; for(int i=1; i<=m; i++) { if(e[i].w!=e[i-1].w) { a[++cnt].l=i; a[cnt-1].r=i-1; } int fx=find(e[i].u); int fy=find(e[i].v); if(fx!=fy) { fa[fx]=fy; a[cnt].c++; tot++; } } a[cnt].r=m; if(tot!=n-1) { printf("0\n"); return 0; } for(int i=1; i<=n; i++) fa[i]=i; for(int i=1; i<=cnt; i++) { sum=0; dfs(i,a[i].l,0); ans=(ans*sum)%mod; for(int j=a[i].l; j<=a[i].r; j++) { int fx=find(e[j].u); int fy=find(e[j].v); if(fx!=fy)fa[fx]=fy; } } printf("%d\n",ans); return 0; }