BZOJ 1016 最小生成树计数(矩阵树定理)
我们把边从小到大排序,然后依次插入一种权值的边,然后把每一个联通块合并。
然后当一次插入的边不止一条时做矩阵树定理就行了。算出有多少种生成树就行了。
剩下的交给乘法原理。
实现一不小心就会让程序变得很丑
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
#define int long long
const int mod=31011;
const int N=110;
int fa[N],a[N][N][N],n,m,b[1010],cnt[N],id[N],w[N],ans[N],mmp[N];
struct edge{
int u,v,w;
}e[1010];
bool cmp(edge a,edge b){
return a.w<b.w;
}
int find(int x){
if(fa[x]==x)return x;
else return fa[x]=find(fa[x]);
}
int gauss(int x,int n){
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
a[x][i][j]=(a[x][i][j]+mod)%mod;
int f=1,ans=1;
for(int i=1;i<=n;i++){
for(int j=i+1;j<=n;j++){
int A=a[x][i][i],B=a[x][j][i];
while(B){
int t=A/B;A%=B;swap(A,B);
for(int k=i;k<=n;k++)a[x][i][k]=(a[x][i][k]-t*a[x][j][k]%mod+mod)%mod;
for(int k=i;k<=n;k++)swap(a[x][i][k],a[x][j][k]);
f=-f;
}
}
ans=ans*a[x][i][i]%mod;
}
memset(a[x],0,sizeof(a[x]));
return (ans*f+mod)%mod;;
}
int read(){
int sum=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){sum=sum*10+ch-'0';ch=getchar();}
return sum*f;
}
void init(){
for(int j=1;j<=n;j++)cnt[j]=0;
for(int j=1;j<=n;j++)fa[j]=j;
}
signed main(){
n=read(),m=read();
for(int i=1;i<=m;i++)e[i].u=read(),e[i].v=read(),e[i].w=read(),b[i]=e[i].w;
sort(b+1,b+1+m);
int num=unique(b+1,b+1+m)-b-1;
for(int i=1;i<=m;i++)e[i].w=lower_bound(b+1,b+1+num,e[i].w)-b;
sort(e+1,e+1+m,cmp);
int now=1;
for(int i=1;i<=n;i++)ans[i]=1;
for(int i=1;i<=num;i++){
init();
int line=now,tmp=0;
while(line<=m&&e[line].w==i){
if(e[line].u==e[line].v){line++;continue;}
int x=find(e[line].u),y=find(e[line].v);
if(x!=y)fa[x]=y;
line++;
}
for(int j=1;j<=n;j++)fa[j]=find(j);
for(int j=1;j<=n;j++)id[j]=++cnt[fa[j]];
for(int j=now;j<=line-1;j++){
if(e[j].u==e[j].v)continue;
a[fa[e[j].u]][id[e[j].u]][id[e[j].u]]++;
a[fa[e[j].v]][id[e[j].v]][id[e[j].v]]++;
a[fa[e[j].u]][id[e[j].u]][id[e[j].v]]--;
a[fa[e[j].u]][id[e[j].v]][id[e[j].u]]--;
}
for(int j=1;j<=n;j++)w[j]=1;
for(int j=1;j<=n;j++)w[fa[j]]=w[fa[j]]*ans[j]%mod;
for(int j=1;j<=n;j++)
if(cnt[j]){
mmp[j]=++tmp;
if(cnt[j]==1)ans[tmp]=w[j];
else ans[tmp]=w[j]*gauss(j,cnt[j]-1)%mod;
}
for(int j=line;j<=m;j++)e[j].u=mmp[fa[e[j].u]],e[j].v=mmp[fa[e[j].v]];
now=line;n=tmp;
}
if(n>1)printf("0");
else printf("%lld",ans[1]);
return 0;
}