【算法】Matrix - Tree 矩阵树定理 & 题目总结

  最近集中学习了一下矩阵树定理,自己其实还是没有太明白原理(证明)类的东西,但想在这里总结一下应用中的一些细节,矩阵树定理的一些引申等等。

  首先,矩阵树定理用于求解一个图上的生成树个数。实现方式是:\(A\)为邻接矩阵,\(D\)为度数矩阵,则基尔霍夫(Kirchhoff)矩阵即为:\(K = D - A\)。具体实现中,记 \(a\) 为Kirchhoff矩阵,则若存在 \(E(u, v)\) ,则\(a[u][u] ++, a[v][v] ++, a[u][v] --, a[v][u] --\) 。即\(a[i][i]\) 为 \(i\) 点的度数,\(a[i][j]\) 为 \(i, j\)之间边的条数的相反数。

  这样构成的矩阵的行列式的值,就为生成树的个数。而求解行列式的快速方法为使用高斯消元进行消元消处上三角矩阵,则有对角线上的值的乘积 = 行列式的值。一般而言求解生成树个数的题目数量会非常庞大,需要取模处理。取模处理中,不能出现小数,于是使用辗转相除法:(其中因为消的是行列式,所以与消方程有所不同。交换两行行列式的值变号,且消元只能将一行的数 * k 之后加到别的行上。)

int Gauss()
{
    int ans = 1;
    for(int i = 1; i < tot; i ++)
    {
        for(int j = i + 1; j < tot; j ++)
            while(f[j][i])
            {
                int t = f[i][i] / f[j][i];
                for(int k = i; k < tot; k ++)
                    f[i][k] = (f[i][k] - t * f[j][k] + mod) % mod;
                swap(f[i], f[j]);
                ans = - ans;
            }
        ans = (ans * f[i][i]) % mod;
    }
    return (ans + mod) % mod;
}

  变元矩阵树定理:求所有生成树的总边积的和。和矩阵树的求法相同,不过行列式中 \(a[i][i]\) 记录的是总的边权和,\(a[i][j]\) 记录 \(i, j\) 之间边权的相反数。

  以下为几道题目:

    1.HEOI2015 小Z的房间     2.SHOI2016 黑暗前的幻想乡

    3.SDOI2014 重建         4.JSOI2008 最小生成树计数

  1.HEOI2015 小Z的房间(妥妥的模板题一个)

#include <bits/stdc++.h>
using namespace std;
#define maxn 90
#define int long long 
#define mod 1000000000
int n, m, f[maxn][maxn];
int tot, Map[maxn][maxn];

int read()
{
    int x = 0, k = 1;
    char c;
    c = getchar();
    while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); }
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * k;
}

void add(int x, int y)
{
    if(x > y) return;
    f[x][x] ++, f[y][y] ++;
    f[x][y] --, f[y][x] --;
}

int Gauss()
{
    int ans = 1;
    for(int i = 1; i < tot; i ++)
    {
        for(int j = i + 1; j < tot; j ++)
            while(f[j][i])
            {
                int t = f[i][i] / f[j][i];
                for(int k = i; k < tot; k ++)
                    f[i][k] = (f[i][k] - t * f[j][k] + mod) % mod;
                swap(f[i], f[j]);
                ans = - ans;
            }
        ans = (ans * f[i][i]) % mod;
    }
    return (ans + mod) % mod;
}

signed main()
{
    n = read(), m = read();
    for(int i = 1; i <= n; i ++)
    {
        char c;
        for(int j = 1; j <= m; j ++)
        {
            cin >> c;
            if(c == '.') Map[i][j] = ++ tot;
        }
    }
    for(int i = 1; i <= n; i ++)
        for(int j = 1; j <= m; j ++)
        {
            int tem, u;
            if(!(u = Map[i][j])) continue;
            if(tem = Map[i - 1][j]) add(u, tem);
            if(tem = Map[i + 1][j]) add(u, tem);
            if(tem = Map[i][j - 1]) add(u, tem);
            if(tem = Map[i][j + 1]) add(u, tem);
        }
    printf("%lld\n", Gauss());
    return 0;
}

  2.SHOI2016黑暗前的幻想乡

  容斥+矩阵树定理。与模板的不同之处在于每一家公司都要参与修建,则合法方案数 = 总的方案数 - 有一个公司未修建的方案数 + 有两个公司未修建的方案数……暴力重构矩阵求解即可。

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int mod = 1000000007;
int n;
ll g[20][20];
vector<pair<int , int > > q[20];

int read()
{
    int x = 0, k = 1;
    char c;
    c = getchar();
    while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); }
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * k;
}

int Gauss()
{
    ll ans = 1;
    for(int i = 1; i < n; i ++)
    {
        for(int j = i + 1; j < n; j ++)
            while(g[j][i])
            {
                ll t = g[i][i] / g[j][i];
                for(int k = i; k < n; k ++)
                    g[i][k] = (g[i][k] - g[j][k] * t) % mod;
                swap(g[i], g[j]);
                ans = -ans;
            }
        ans = (ans * g[i][i]) % mod;
        if(!ans) return 0;
    }
    return (ans + mod) % mod;
}

int main()
{
    n = read();
    for(int i = 1; i < n; i ++)
    {
        int m = read();
        for(int j = 1; j <= m; j ++)
        {
            int x = read(), y = read();
            q[i].push_back(make_pair(x, y));
        }
    }
    int ans = 0, CNST = 1 << (n - 1);
    for(int i = 0; i < CNST; i ++)
    {
        int cnt = 0; memset(g, 0, sizeof(g));
        for(int j = 1; j < n; j ++)
            if(i & (1 << (j - 1)))
            {
                for(int k = 0; k < q[j].size(); k ++)
                {
                    int x = q[j][k].first, y = q[j][k].second;
                    g[x][x] ++, g[y][y] ++;
                    g[x][y] --, g[y][x] --;
                }
                cnt ++;
            }
        if((n - cnt) & 1) ans = (ans + Gauss()) % mod;
        else ans = (ans - Gauss() + mod) % mod;
    }
    printf("%d\n", ans);
    return 0;
}

  3.SDOI2014重建

  化式子 + 变元矩阵树定理。将概率的式子写出来变形即可得到矩阵树定理求 \(\prod \frac{p(u, v)}{1 - p(u, v)}\)

#include <bits/stdc++.h>
using namespace std;
#define maxn 100
#define db double
#define eps 0.000001
int n;
db ans = 1.0, a[maxn][maxn];

db Gauss(int n)
{
    db ans = 1.0;
    for(int i = 1; i <= n; i ++)
    {    
        for(int j = i + 1; j <= n; j ++)
        {
            int t = i;
            if(fabs(a[j][i]) > fabs(a[t][i])) t = j;
            if(t != i) swap(a[t], a[i]), ans = -ans;
        } 
        for(int j = i + 1; j <= n; j ++)
        {
            db t = a[j][i] / a[i][i];
            for(int k = i; k <= n; k ++)
                a[j][k] -= t * a[i][k];
        }
        ans *= a[i][i];
    }
    return fabs(ans);
}

int main()
{
    scanf("%d", &n);
    for(int i = 1; i <= n; i ++)
        for(int j = 1; j <= n; j ++)
            scanf("%lf", &a[i][j]);
    for(int i = 1; i <= n; i ++)
        for(int j = 1; j <= n; j ++)
        {
            db t = fabs(1.0 - a[i][j]) < eps ? eps : (1.0 - a[i][j]);
            if(i < j) ans *= t;
            a[i][j] = a[i][j] / t;
        }
    for(int i = 1; i <= n; i ++)
        for(int j = 1; j <= n; j ++)
            if(i != j) { a[i][i] += a[i][j], a[i][j] = -a[i][j]; }
    printf("%.10lf\n", Gauss(n - 1) * ans);
    return 0;
} 

  4.JSOI2008最小生成树计数

  这题虽然最早年,然而也最强啊……个人认为这位博主解释得很好了 Z-Y-Y-S的博客

  两个性质 mark 一下:

  

  

 

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
#define maxn 200
#define mod 31011
int n, m, ans = 1, tmp[maxn];
int sum, fa[maxn], set[maxn];
int a[maxn][maxn];

struct edge
{
    int u, v, w;
}E[maxn * 20], e[maxn * 20];

int read()
{
    int x = 0, k = 1;
    char c;
    c = getchar();
    while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); }
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * k;
}

bool cmp(edge a, edge b) { return a.w < b.w; }

int find(int x) { return set[x] == x ? x : set[x] = find(set[x]); }
int find2(int x) { return fa[x] == x ? x : fa[x] = find2(fa[x]); }

int Gauss(int n)
{
    int ans = 1;
    for(int i = 1; i <= n; i ++)
        for(int j = 1; j <= n; j ++)
            a[i][j] = (a[i][j] + mod) % mod;
    for(int i = 1; i <= n; i ++)
    {
        for(int j = i + 1; j <= n; j ++)
            while(a[j][i])
            {
                int t = a[i][i] / a[j][i];
                for(int k = i; k <= n; k ++)
                    a[i][k] = (a[i][k] - 1ll * t * a[j][k] % mod + mod) % mod;
                swap(a[i], a[j]); ans = - ans;
            }
        ans = 1ll * ans * a[i][i] % mod;
    }
    return (ans + mod) % mod;
}

void Cal(int S, int T)
{
    int cnt = 0;
    for(int i = S; i <= T; i ++)
    {
        e[i] = E[i];
        int p = find(e[i].u), q = find(e[i].v);
        e[i].u = p, e[i].v = q;
        if(p == q) continue;
        tmp[++ cnt] = p, tmp[++ cnt] = q;
    }
    sort(tmp + 1, tmp + 1 + cnt);
    cnt = unique(tmp + 1, tmp + cnt + 1) - tmp - 1;
    memset(a, 0, sizeof(a));
    for(int i = 1; i <= cnt; i ++) fa[i] = i;
    for(int i = S; i <= T; i ++)
    {
        if(e[i].u == e[i].v) continue;
        int p = find(e[i].u), q = find(e[i].v);
        if(p != q) -- sum, set[p] = q;
        int u = lower_bound(tmp + 1, tmp + cnt + 1, e[i].u) - tmp;
        int v = lower_bound(tmp + 1, tmp + cnt + 1, e[i].v) - tmp;
        a[u][u] ++, a[v][v] ++;
        a[u][v] --, a[v][u] --;
        p = find2(u), q = find2(v);
        if(p != q) fa[p] = q;
    }
    for(int i = 2; i <= cnt; i ++)
        if(find2(i) != find2(i - 1))
        {
            int p = find2(i), q = find2(i - 1);
            a[p][p] ++, a[q][q] ++;
            a[p][q] --, a[q][p] --;
            fa[p] = q;
        }
    ans = 1ll * ans * Gauss(cnt - 1) % mod;
}

int main()
{
    n = read(), m = read();
    for(int i = 1; i <= m; i ++)
        E[i].u = read(), E[i].v = read(), E[i].w = read();
    sort(E + 1, E + 1 + m, cmp);
    for(int i = 1; i <= n; i ++) set[i] = i;
    sum = n; 
    for(int i = 1, j; i <= m; i = j)
    {
        for(j = i; j <= m; j ++)
            if(E[j].w != E[i].w) break;
        if(j - i > 1) Cal(i, j - 1);
        else 
        {
            int p = find(E[i].u), q = find(E[i].v);
            if(p != q) set[p] = q;
            sum --;
        }
    }
    if(sum > 1) printf("0");
    else printf("%d\n", ans);
    return 0;
} 

 

posted @ 2018-05-20 18:59  Twilight_Sx  阅读(6917)  评论(7编辑  收藏  举报