BZOJ 1016 最小生成树计数
题目链接
https://www.lydsy.com/JudgeOnline/problem.php?id=1016
相关博客
https://blog.csdn.net/sdfzyhx/article/details/52075151
kur产生的最小生成树只是一组解,但不一定是唯一的解。
(具体分析不写了)有一个结论,任意最小生成树中,固定边权选取的数量是一定的,只是相同的边权,选取的方式不同,造成的多个解。
然后对于相同的边权,子集枚举满足的方案数数量。最后根据乘法原理求解。
AC代码
#include<iostream> #include<cstdio> #include<algorithm> #include<cmath> #include<cstring> using namespace std; const int mxn=10000; int n,m; int sum; int ans=1; // struct edge{//边 int x,y; int v; }e[mxn]; struct segment{ int st,ed;//区块起止点 int v; }se[mxn]; int cnt; int cmp(const edge a,const edge b){ return a.v<b.v; } // int fa[mxn]; int find(int x){ if(fa[x]==x)return x; return find(fa[x]);//不可压缩 } // void dfs(int x,int now,int t){//x 组编号 now现在处理的边编号 t使用的边编号 if(now==se[x].ed+1){ if(t==se[x].v)sum++; return; } int u=find(e[now].x),v=find(e[now].y); if(u!=v){ fa[u]=v; dfs(x,now+1,t+1);//选用这条边 fa[u]=u;fa[v]=v;//还原状态 } dfs(x,now+1,t);//不选这条边 return; } int main(){ scanf("%d%d",&n,&m); int i,j; for(i=1;i<=n;i++)fa[i]=i;//初始化并查集,处理边的连通 for(i=1;i<=m;i++) scanf("%d%d%d",&e[i].x,&e[i].y,&e[i].v); sort(e+1,e+m+1,cmp); int tot=0;//联通边数 for(i=1;i<=m;i++){ if(e[i].v!=e[i-1].v){//如果权值与之前不同 se[cnt].ed=i-1;se[++cnt].st=i;//分到新的一组 } int u=find(e[i].x); int v=find(e[i].y); if(u!=v){fa[u]=v;se[cnt].v++;tot++;} } se[cnt].ed=m; if(tot!=n-1){printf("0");return 0;}//未联通 for(i=1;i<=n;i++)fa[i]=i;//初始化并查集,处理边组的连通 for(i=1;i<=cnt;i++){ sum=0; dfs(i,se[i].st,0); ans=(ans*sum)%31011; for(j=se[i].st;j<=se[i].ed;j++){ int u=find(e[j].x),v=find(e[j].y); if(u!=v)fa[u]=v; } } printf("%d",ans%31011); return 0; }
当重复的边多了只能缩点用矩阵数定理来写了。
#include<iostream> #include<algorithm> #include<cstring> #include<vector> #include<cstdio> #define p 31011 #define N 1003 using namespace std; int a[12][12],c[N][N],n,m,vis[N],fa[N],U[N]; vector<int> V[N]; struct data { int x,y,c; bool operator<(const data &a)const { return c<a.c; } } e[N]; int gauss(int n) { for (int i=1; i<=n; i++) for (int j=1; j<=n; j++) a[i][j]%=p; //for (int i=1;i<=n;i++,cout<<endl) // for (int j=1;j<=n;j++) cout<<a[i][j]<<" "; int ret=1; for (int i=1; i<=n; i++) { int num=i; for (int j=i+1; j<=n; j++) if (abs(a[j][i])) num=j; for (int j=1; j<=n; j++) swap(a[num][j],a[i][j]); if (num!=i) ret=-ret; for (int j=i+1; j<=n; j++) while (a[j][i]) { int t=a[j][i]/a[i][i]; for (int k=1; k<=n; k++) a[j][k]=(a[j][k]-t*a[i][k])%p; if (!a[j][i]) break; ret=-ret; for (int k=1; k<=n; k++) swap(a[i][k],a[j][k]); } ret=(ret*a[i][i])%p; } //cout<<ret<<endl; return (ret%p+p)%p; } int find(int x,int f[N]) { if (x==f[x]) return x; else return find(f[x],f); } int main() { //freopen("bzoj_1016.in","r",stdin); // freopen("bzoj_1016.out","w",stdout); scanf("%d%d",&n,&m); for (int i=1; i<=m; i++) scanf("%d%d%d",&e[i].x,&e[i].y,&e[i].c); sort(e+1,e+m+1); for (int i=1; i<=n; i++) vis[i]=0,fa[i]=i; int ans=1; int last=-1; for (int i=1; i<=m+1; i++) { if (e[i].c!=last||i==m+1) { for (int j=1; j<=n; j++) if (vis[j]) { int r1=find(j,U); V[r1].push_back(j); vis[j]=0; } for (int j=1; j<=n; j++) if (V[j].size()>1) { memset(a,0,sizeof(a)); int len=V[j].size(); for (int k=0; k<len; k++) for (int l=k+1; l<len; l++) { int x=V[j][k]; int y=V[j][l]; int t=c[x][y]; a[k+1][l+1]-=t; a[l+1][k+1]-=t; a[k+1][k+1]+=t; a[l+1][l+1]+=t; } ans=ans*gauss(len-1)%p; for (int k=0; k<len; k++) fa[V[j][k]]=j; } for (int j=1; j<=n; j++) { U[j]=fa[j]=find(j,fa); V[j].clear(); } last=e[i].c; if(i==m+1) break; } int x=e[i].x; int y=e[i].y; int r1=find(x,fa); int r2=find(y,fa); if (r1==r2) continue; U[find(r1,U)]=find(r2,U); vis[r1]=1; vis[r2]=1; c[r1][r2]++; c[r2][r1]++; } int flag=1; for (int i=2; i<=n; i++) if (find(i,U)!=find(i-1,U)) flag=0; ans=(ans*flag%p+p)%p; printf("%d\n",ans); }