BZOJ 1016 [JSOI2008]最小生成树计数 ——Matrix-Tree定理
考虑从小往大加边,然后把所有联通块的生成树个数计算出来。
然后把他们缩成一个点,继续添加下一组。
最后乘法原理即可。
写起来很恶心
#include <queue> #include <cmath> #include <vector> #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; #define F(i,j,k) for (int i=j;i<=k;++i) #define D(i,j,k) for (int i=j;i>=k;--i) #define maxn 1005 #define eps 1e-6 const int md=31011; vector <int> v,to[maxn]; queue <int> q; struct Edge{int u,v,w;}a[maxn]; int n,m,fa[maxn]; int b[maxn][maxn],inv[maxn]; int vcnt,du[maxn],list[maxn],vis[maxn]; bool cmp(Edge x,Edge y) {return x.w<y.w;} int gf(int k) {if (fa[k]==k) return k; else return fa[k]=gf(fa[k]);} int gauss(int n) { F(i,1,n) F(j,1,n) b[i][j]%=md; int ret=1; for (int i=1;i<n;++i) { for (int j=i+1;j<n;++j) while (b[j][i]) { int t=b[i][i]/b[j][i]; for (int k=i;k<n;++k) b[i][k]=(b[i][k]-b[j][k]*t+md)%md; for (int k=i;k<n;++k) swap(b[i][k],b[j][k]); ret=-ret; } if (b[i][i]==0) return 0; ret=ret*b[i][i]%md; } return abs((ret+md)%md); } int main() { scanf("%d%d",&n,&m); F(i,1,n) fa[i]=i; F(i,1,m){scanf("%d%d%d",&a[i].u,&a[i].v,&a[i].w);} sort(a+1,a+m+1,cmp); int now=1,ans=1; while (now<=m) { int l=now,r=now; vcnt=0; memset(du,0,sizeof du); F(i,1,n) to[i].clear(); while (a[r+1].w==a[r].w) r++; now=r+1; F(i,l,r) { int fl=gf(a[i].u),fr=gf(a[i].v); to[fl].push_back(fr); to[fr].push_back(fl); if (fl!=fr) du[fl]++,du[fr]++; } memset(vis,0,sizeof vis); F(i,1,n) if (du[i]&&!vis[i]) { v.clear(); memset(b,0,sizeof b); memset(inv,0,sizeof inv); q.push(i);inv[i]=1;vis[i]=1; while (!q.empty()) { int x=q.front();v.push_back(x);q.pop(); for (int j=0;j<to[x].size();++j) if (!vis[to[x][j]]) q.push(to[x][j]),inv[to[x][j]]=1,vis[to[x][j]]=1; } for (int j=0;j<v.size();++j) list[v[j]]=j+1; for (int j=0;j<v.size();++j) for (int k=0;k<to[v[j]].size();++k) if (inv[to[v[j]][k]]) { b[list[v[j]]][list[v[j]]]++,b[list[v[j]]][list[to[v[j]][k]]]--; } ans*=gauss(v.size()); ans%=md; } F(i,l,r) { int fl=gf(a[i].u),fr=gf(a[i].v); if (fl!=fr){fa[fl]=fr;} } } int cnt=0; F(i,1,n) if (fa[i]==i) { cnt++; if (cnt==2) {printf("0\n"); return 0;} } printf("%d\n",ans); }