Matrix_tree Theorem 矩阵树定理学习笔记
Matrix_tree Theorem:
给定一个无向图, 定义矩阵A
A[i][j] = - (<i, j>之间的边数)
A[i][i] = 点i的度数
其生成树的个数等于 A的任意n - 1阶主子式的值。
关于定理的相关证明 可以看这篇文章, 讲得非常详细, 耐心看就能看懂:
关于求行列式, 可以用高斯消元。 如果是模域下求行列式, 可以用欧几里得算法。 具体实现看这篇文章
模域下求行列式 模板题:SPOJ DETER3
代码:
1 #include <cstdio>
2 #include <iostream>
3 #include <queue>
4 #include <algorithm>
5 #include <cstring>
6 #include <set>
7 #include <cmath>
8 using namespace std;
9
10 #define N 220
11 #define M 400010
12 typedef long long ll;
13
14 const int Mod=1000000007;
15 const double eps = 1e-9;
16
17 ll Solve(ll a[N][N], int n, ll mod)
18 {
19 ll res = 1;
20 for (int i = 1; i <= n; ++i)
21 {
22 for (int j = i; j <= n; ++j)
23 {
24 if (a[j][i] < 0)
25 {
26 res *= -1;
27 for (int k = i; k <= n; ++k)
28 a[j][k] *= -1;
29 }
30 }
31 int j;
32 for (j = i; j <= n && !a[j][i]; ++j);
33 if (j > n) return 0;
34
35 if (j != i)
36 {
37 res = -res;
38 for (int k = i; k <= n; ++k) swap(a[i][k], a[j][k]);
39 }
40
41 for (j = i + 1; j <= n; ++j)
42 {
43 while (a[j][i])
44 {
45 ll d = a[i][i] / a[j][i];
46 for (int k = i; k <= n; ++k) a[i][k] -= d * a[j][k] % mod, a[i][k] %= mod;
47 for (int k = i; k <= n; ++k) swap(a[i][k], a[j][k]);
48 res = -res;
49 }
50 }
51 res = res * a[i][i] % mod;
52 }
53 if (res < 0) res += mod;
54 return res;
55 }
56
57 int main()
58 {
59 //freopen("in.in","r",stdin);
60 //freopen("out.out","w",stdout);
61
62 int n; ll mod;
63 while (scanf("%d %lld", &n, &mod) != EOF)
64 {
65 ll a[N][N];
66 for (int i = 1; i <= n; ++i)
67 for (int j = 1; j <= n; ++j)
68 scanf("%lld", &a[i][j]);
69 printf("%lld\n", Solve(a, n, mod));
70 }
71
72 return 0;
73 }
下面给出一些应用(练习题):
应用一:SPOJ HIGH
模板题: 给出一个无向图, 求生成树个数。
代码:
1 #include <cstdio>
2 #include <iostream>
3 #include <queue>
4 #include <algorithm>
5 #include <cstring>
6 #include <set>
7 #include <cmath>
8 using namespace std;
9
10 #define N 13
11 #define M 400010
12 typedef long long ll;
13
14 const int Mod=1000000007;
15 const double eps = 1e-9;
16
17 double Solve(int n, double a[N][N])
18 {
19 if (n == 0) return 1;
20
21 /*for (int i = 1; i <= n; ++i)
22 for (int j = 1; j <= n; ++j)
23 printf("%.0lf%c", a[i][j], j == n? '\n':' ');*/
24
25 double res = 1;
26 for (int i = 1; i <= n; ++i)
27 {
28 int j;
29 for (j = i; j <= n && fabs(a[j][i]) < eps; ++j);
30 if (j > n) return 0;
31 if (j != i) for (int k = i; k <= n; ++k) swap(a[i][k], a[j][k]);
32
33 for (int j = i + 1; j <= n; ++j)
34 {
35 double f = a[j][i] / a[i][i];
36 for (int k = i; k <= n; ++k)
37 a[j][k] -= f * a[i][k];
38 }
39 res *= a[i][i];
40 }
41
42 return res;
43 }
44
45 int main()
46 {
47 //freopen("in.in","r",stdin);
48 //freopen("out.out","w",stdout);
49
50 int T, n, m, x, y;
51 scanf("%d", &T);
52 while (T--)
53 {
54 double a[N][N] = {0};
55 scanf("%d %d", &n, &m);
56 for (int i = 1; i <= m; ++i)
57 {
58 scanf("%d %d", &x, &y);
59 a[x][y]--, a[y][x]--;
60 a[x][x]++, a[y][y]++;
61 }
62 printf("%.0lf\n", Solve(n - 1, a));
63 }
64
65 return 0;
66 }
应用二:BZOJ 4031
构图后 求生成树个数 mod 一个数。
1 #include <cstdio>
2 #include <iostream>
3 #include <queue>
4 #include <algorithm>
5 #include <cstring>
6 #include <set>
7 #include <cmath>
8 using namespace std;
9
10 #define N 120
11 #define M 400010
12 typedef long long ll;
13
14 const int Mod=1000000007;
15 const double eps = 1e-9;
16
17 ll Solve(ll a[N][N], int n, ll mod)
18 {
19 if (n == 0) return 1;
20 ll res = 1;
21 for (int i = 1; i <= n; ++i)
22 {
23 //for (int p = 1; p <= n; ++p)
24 // for (int q = 1; q <= n; ++q)
25 // printf("%lld%c", a[p][q], q == n? '\n':' ');
26
27 for (int j = i; j <= n; ++j)
28 {
29 if (a[j][i] < 0)
30 {
31 res = -res;
32 for (int k = i; k <= n; ++k) a[j][k] *= -1;
33 }
34 }
35
36 int j;
37 for (j = i; j <= n && !a[j][i]; ++j);
38 if (j > n) return 0;
39
40 if (j != i)
41 {
42 res = -res;
43 for (int k = i; k <= n; ++k) swap(a[i][k], a[j][k]);
44 }
45
46
47 for (j = i + 1; j <= n; ++j)
48 {
49 while (a[j][i])
50 {
51 ll d = a[i][i] / a[j][i];
52 for (int k = i; k <= n; ++k) a[i][k] -= d * a[j][k] % mod, a[i][k] %= mod;
53 res = -res;
54 for (int k = i; k <= n; ++k) swap(a[i][k], a[j][k]);
55 }
56 }
57 res = res * a[i][i] % mod;
58 // printf("res = %lld\n", res);
59 }
60 if (res < 0) res += mod;
61 // cout << "aa= "<<res <<endl;
62 return res;
63 }
64
65 int main()
66 {
67 //freopen("in.in","r",stdin);
68 //freopen("out.out","w",stdout);
69
70 int n, m; ll mod = 1000000000;
71 char mp[11][11]; int tot = 0;
72 int id[11][11];
73 scanf("%d %d", &n, &m);
74 for (int i = 1; i <= n; ++i)
75 {
76 for (int j = 1; j <= m; ++j)
77 {
78 scanf(" %c", &mp[i][j]);
79 if (mp[i][j] == '.') id[i][j] = ++tot;
80 }
81 }
82 ll a[N][N] = {0};
83 for (int i = 1; i < n; ++i)
84 {
85 for (int j = 1; j <= m; ++j)
86 {
87 if (mp[i][j] == '.' && mp[i + 1][j] == '.')
88 {
89 int x = id[i][j], y = id[i + 1][j];
90 a[x][y]--, a[y][x]--;
91 a[x][x]++, a[y][y]++;
92 }
93 }
94 }
95 for (int i = 1; i <= n; ++i)
96 {
97 for (int j = 1; j < m; ++j)
98 {
99 if (mp[i][j] == '.' && mp[i][j + 1] == '.')
100 {
101 int x = id[i][j], y = id[i][j + 1];
102 a[x][y]--, a[y][x]--;
103 a[x][x]++, a[y][y]++;
104 }
105 }
106 }
107 printf("%lld\n", Solve(a, tot - 1, mod));
108 return 0;
109 }
应用三: BZOJ 2467
这题数据范围比较小,可以暴力建图 然后跑Matrix tree。
另外可以直接推公式:
一共有4n个点, 5n条边, 所以要删去n - 1条边, 然后可以发现 每个五边形外面的4条边最多只能删一条。
根据鸽笼原理合法的解 一定是 有一个五边形删去了里面的那条边 和外面的某条边, 其余的五边形删去了任意一条边。
所以答案就是$4*n*5^{n-1}$
Matrix tree 代码:
1 #include <iostream>
2 #include <cstdio>
3 #include <cstring>
4 #include <algorithm>
5 #include <vector>
6 #include <cmath>
7 #include <queue>
8 #include <map>
9 using namespace std;
10
11 #define X first
12 #define Y second
13 #define N 420
14 #define M 11
15
16 typedef long long ll;
17 const int Mod = 1000000007;
18 const int INF = 1 << 30;
19
20 void Add(int x, int y, int &tot, int a[N][N])
21 {
22 a[x][tot + 1] = a[tot + 1][x] = -1;
23 a[tot + 1][tot + 2] = a[tot + 2][tot + 1] = -1;
24 a[tot + 2][tot + 3] = a[tot + 3][tot + 2] = -1;
25 a[tot + 3][y] = a[y][tot + 3] = -1;
26 a[tot + 1][tot + 1] = a[tot + 2][tot + 2] = a[tot + 3][tot + 3] = 2;
27 a[x][x] = a[y][y] = 4; tot += 3;
28 a[x][y]--,a[y][x]--;
29 }
30
31 int Solve(int n, int a[N][N], int mod)
32 {
33
34 if (n == 0) return 1;
35 int res = 1;
36 for (int i = 1; i <= n; ++i)
37 {
38 for (int j = i; j <= n; ++j)
39 {
40 if (a[j][i] < 0)
41 {
42 res = -res;
43 for (int k = i; k <= n; ++k) a[j][k] = -a[j][k];
44 }
45 }
46 //cout << i << endl;
47 //for (int p = 1; p <= n; ++p)
48 // for (int q = 1; q <= n; ++q)
49 // printf("%d%c", a[p][q], q == n? '\n':' ');
50 //printf("\n");
51 int j;
52 for (j = i; j <= n && !a[j][i]; ++j);
53 if (j > n) return 0;
54 if (i != j)
55 {
56 res = -res;
57 for (int k = i; k <= n; ++k) swap(a[i][k], a[j][k]);
58 }
59
60 for (j = i + 1; j <= n; ++j)
61 {
62 while (a[j][i])
63 {
64 int d = a[i][i] / a[j][i];
65 for (int k = i; k <= n; ++k) a[i][k] -= d * a[j][k] % mod, a[i][k] %= mod;
66 res = -res;
67 for (int k = i; k <= n; ++k) swap(a[i][k], a[j][k]);
68 }
69 }
70 res = res * a[i][i] % mod;
71 }
72 if (res < 0) res += mod;
73 return res;
74 }
75
76 int main()
77 {
78 //freopen("in.in","r",stdin);
79 //freopen("out.out","w",stdout);
80
81 int T; scanf("%d", &T);
82 while (T--)
83 {
84 int n, mod = 2007, tot, a[N][N] = {0}; scanf("%d", &n); tot = n;
85 for (int i = 1; i < n; ++i) Add(i, i + 1, tot, a);
86 Add(n, 1, tot, a);
87 printf("%d\n", Solve(tot - 1, a, mod));
88 }
89 return 0;
90 }
应用四:BZOJ 1016
题目大意:求最小生成树个数。
性质一:无向图所有MST中,相同权值的边数一样多。
证明看https://blog.sengxian.com/solutions/bzoj-1016
性质二:对于任意MST,加入所有权值<=w的边后, 形成的森林连通性相同 。
证明:
考虑Kruskal算法的过程,我们首先会尽可能多的加入权值最小的边。这个过程相当于拿出所有权值最小的边,然后任意求一颗生成树,因此我们可以知道,能加入的权值最小的边的数量是一定的,而且加入这些边之后 形成的森林连通性相同。
结合性质一,对于任意一棵MST,因为它包含的权值最小的边数和做Kruskal算法求出的MST包含的边数是一样的,这些边又不能形成环,因此这些边形成的森林和 做Kruskal时形成的森林连通性是一样的。 对于任意MST,加入所有权值最小的边后, 形成的森林连通性相同 。
然后我们考虑把已经形成的联通块缩点, 考虑所有权值第二小的边,重复上面的过程,可以证明对于任意MST,加入所有权值<=w的边后, 形成的森林连通性相同 。
所以我们的算法就可以模拟这个过程, 把边按权值从小到大排好序, 每次加入权值相同的所有边, 形成一些连通块, 然后对于每个连通块, 跑Matrix tree 求出形成这个连通块有多少种方案。 统计好之后每个连通块缩点, 进行下一种权值的边的加边操作。 代码实现起来还是挺多细节的。
代码:
1 #include <iostream>
2 #include <cstdio>
3 #include <cstring>
4 #include <algorithm>
5 #include <vector>
6 #include <cmath>
7 #include <queue>
8 #include <map>
9 using namespace std;
10
11 #define X first
12 #define Y second
13 #define N 120
14 #define M 1010
15
16 typedef long long ll;
17 const int Mod = 1000000007;
18 const int INF = 1 << 30;
19
20 struct Edge
21 {
22 int x, y, z;
23 bool operator < (const Edge &t)const{return z < t.z;}
24 }e[M];
25
26
27 int Solve(int n, int a[N][N], int mod)
28 {
29 if (n == 0) return 1;
30 int res = 1;
31 for (int i = 0; i < n; ++i)
32 {
33 for (int j = i; j < n; ++j)
34 {
35 if (a[j][i] < 0)
36 {
37 res = -res;
38 for (int k = i; k < n; ++k) a[j][k] = -a[j][k];
39 }
40 }
41 int j;
42 for (j = i; j < n && !a[j][i]; ++j);
43 if (j == n) return 0;
44 if (i != j)
45 {
46 res = -res;
47 for (int k = i; k < n; ++k) swap(a[i][k], a[j][k]);
48 }
49
50 for (j = i + 1; j < n; ++j)
51 {
52 while (a[j][i])
53 {
54 int d = a[i][i] / a[j][i];
55 for (int k = i; k < n; ++k) a[i][k] -= d * a[j][k] % mod, a[i][k] %= mod;
56 res = -res;
57 for (int k = i; k < n; ++k) swap(a[i][k], a[j][k]);
58 }
59 }
60 res = res * a[i][i] % mod;
61 }
62 if (res < 0) res += mod;
63 return res;
64 }
65
66
67 int father[N], id[N], pa[N], num[N];
68 vector<int> lis[N];
69 vector<pair<int, int> > ed[N];
70
71 int Find(int x)
72 {
73 if (father[x] == x) return x;
74 father[x] = Find(father[x]);
75 return father[x];
76 }
77
78 void Merge(int x, int y)
79 {
80 x = Find(x), y = Find(y);
81 if (x != y) father[x] = y;
82 }
83
84 int main()
85 {
86 //freopen("in.in","r",stdin);
87 //freopen("out.out","w",stdout);
88
89 int n, m, mod = 31011;
90 scanf("%d %d", &n, &m);
91 for (int i = 1; i <= m; ++i) scanf("%d %d %d", &e[i].x, &e[i].y, &e[i].z);
92 sort(e + 1, e + m + 1);
93
94 int res = 1, block = n;
95 for (int i = 1; i <= n; ++i) id[i] = i;
96 for (int l = 1, r; l <= m;)
97 {
98 for (r = l; r < m && e[r + 1].z == e[l].z; ++r);
99 for (int i = 1; i <= block; ++i) father[i] = i;
100 for (r = l; r < m && e[r + 1].z == e[l].z; ++r);
101 for (int i = l; i <= r; ++i) Merge(id[e[i].x], id[e[i].y]);
102
103 int tot = 0;
104 for (int i = 1; i <= block; ++i) if (father[i] == i) pa[i] = ++tot;
105 for (int i = 1; i <= block; ++i) pa[i] = pa[Find(i)];
106 for (int i = 1; i <= block; ++i) lis[pa[i]].push_back(i), num[i] = lis[pa[i]].size() - 1;
107 for (int i = l; i <= r; ++i)
108 {
109 int x = id[e[i].x], y = id[e[i].y];
110 if (x == y) continue;
111 ed[pa[x]].push_back(make_pair(num[x], num[y]));
112 }
113 for (int i = 1; i <= tot; ++i)
114 {
115 int a[N][N] = {0}, x, y;
116 for (int j = 0; j < ed[i].size(); ++j)
117 {
118 x = ed[i][j].X, y = ed[i][j].Y;
119 a[x][x]++, a[y][y]++;
120 a[x][y]--, a[y][x]--;
121 }
122 res = res * Solve(lis[i].size() - 1, a, mod) % mod;
123 }
124
125 for (int i = 1; i <= n; ++i) id[i] = pa[id[i]];
126 for (int i = 1; i <= tot; ++i) lis[i].clear(), ed[i].clear();
127 block = tot; l = r + 1;
128 }
129 if (block > 1) puts("0");
130 else printf("%d\n", res);
131 return 0;
132 }
应用五: BZOJ 3534 Matrix Tree Theorem 的扩展, 非常精彩的题。
题目大意:
给出一个无向图, 两点之间的连边会有一个概率, 求连成一颗树的概率。
这个题如果没有看过Matrix Tree Theorem定理的证明,只是记住结论 应该是做不出来的。。。
先来看一个简化版本:
给出一个无向图, 定义它的一棵生成树的权值为所有边权的乘积。 求所有生成树的权值和。 ( 原题还要考虑一些边不选的概率 这题只靠考虑选的边)
参考最前面给出的 证明Matrix Tree Theorem定理的文章。
原来是
A[i][j] = - (<i, j>之间的边数)
A[i][i] = 点i的度数
现在改成
A[i][j] = - (<i, j>之间的所有边权和)
A[i][i] = 和i相连的所有边的边权和
修改关联矩阵B的定义, 把1 改成 $e_j$的边权开根号后的值。
这里也做相应的修改, 把 -1 改成 -<i,j>之间的边权和, the degree of i 改成和i相连的所有边的边权和。
做了以上修改之后刚好就是所选的生成树的边权的乘积。
所以用修改后的A数组跑Matrix Tree就可以解决这个问题了。
当然原题 还要乘上 其他边不选的概率。 再对A数组做点小修改就好了。具体实现看代码:
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <algorithm> 5 #include <vector> 6 #include <cmath> 7 #include <queue> 8 #include <map> 9 using namespace std; 10 11 #define X first 12 #define Y second 13 #define N 100 14 #define M 11 15 16 typedef long long ll; 17 const int Mod=1000000007; 18 const int INF=1<<30; 19 20 const double eps = 1e-10; 21 22 double Solve(double a[N][N], int n) 23 { 24 if (n == 0) return 1; 25 double res = 1; 26 for (int i = 1; i <= n; ++i) 27 { 28 int j = i; 29 for (int k = i + 1; k <= n ; ++k) if (fabs(a[k][i]) > fabs(a[j][i])) j = k; 30 if (fabs(a[j][i]) < eps) return 0; 31 if (i != j) 32 { 33 res = -res; 34 for (int k = i; k <= n; ++k) swap(a[i][k], a[j][k]); 35 } 36 37 for (j = i + 1; j <= n; ++j) 38 { 39 double f = a[j][i] / a[i][i]; 40 for (int k = i; k <= n; ++k) a[j][k] -= f * a[i][k]; 41 } 42 res *= a[i][i]; 43 } 44 return res; 45 } 46 47 int main() 48 { 49 //freopen("in.in","r",stdin); 50 //freopen("out.out","w",stdout); 51 52 int n; 53 double a[N][N] = {0}, res = 1; 54 55 scanf("%d", &n); 56 for (int i = 1; i <= n; ++i) 57 { 58 for (int j = 1; j <= n; ++j) 59 { 60 scanf("%lf", &a[i][j]); 61 if (i == j) continue; 62 if (i < j && fabs(a[i][j] - 1) > eps) res *= 1 - a[i][j]; 63 if (fabs(a[i][j] - 1) > eps) a[i][j] = -a[i][j] / (1 - a[i][j]); 64 else a[i][j] = -a[i][j]; 65 } 66 } 67 for (int i = 1; i <= n; ++i) 68 for (int j = 1; j <= n; ++j) 69 if (i != j) a[i][i] -= a[i][j]; 70 printf("%.8lf\n", fabs(res * Solve(a, n - 1))); 71 return 0; 72 }