[模板]矩阵快速幂(斐波那契数列)

首先回顾矩阵乘法的定义

\[c_{ij}=\sum_{i=1}^{k}a_{ik} \times b_{kj} \]

显然,对\(F_{n}=F_{n-1}+F_{n-2}\)这样的柿子,我们可以用待定系数法求得递推矩阵:

\(f_{n}=\begin{bmatrix} F_{n} & F_{n-1} \end{bmatrix}(n > 1)\),解\(f_{n}=f_{n-1} \begin{bmatrix} a & b \\ c & d \end{bmatrix}\)

\[\begin{bmatrix} a & b \\ c & d \end{bmatrix} = \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix} \]

于是

\[f_{n} = f_{2} \times\begin{bmatrix} 1 & 1\\ 1 &0 \end{bmatrix}^{n-2} \]

利用矩阵快速幂即可\(O(\log_{}{n})\)求解\(F_{n}\)
我们定义一个结构体来进行矩阵的表示,初始化,乘法运算等

struct Matrix {
  ll a[maxn][maxn];
  ll r, c;

  Matrix() {
    memset(a, 0, sizeof(a));//构造时进行初始化
  }

  void print() const { // for debug
    if (r < 0 || c < 0) return;
    for (int i = 1; i <= r; i++) {
      for (int j = 1; j <= c; j++) {
        printf("%lld ", a[i][j]);
      }
      puts("");
    }
  }

  Matrix operator * (const Matrix b) {
    Matrix res;
    if (c != b.r) {
      puts("error");
      exit(0);
    }
    res.r = r; res.c = b.c;
    //print(); b.print();
    for (int i = 1; i <= r; i++)
      for (int j = 1; j <= c; j++)
        for (int k = 1; k <= c; k++)
          res.a[i][j] = (res.a[i][j] + a[i][k] * b.a[k][j]) % mod;
    //res.print();
    return res;
  }
};

然后快速幂自然也是很简单的啦

Matrix matrix_pow(Matrix m, ll n) {
  Matrix res;
  res.r = res.c = m.r;
  for (int i = 1; i <= res.r; i++) res.a[i][i] = 1;//单位矩阵
  while (n) {
    if (n & 1)  res = res * m;
    m = m * m;
    n >>= 1;
  }
  return res;
}

附上完整代码

点击查看代码
#include <iostream>
#include <cstring>
#define ll long long
#define mod 1000000007
#define maxn 3
using namespace std;

struct Matrix {
  ll a[maxn][maxn];
  ll r, c;

  Matrix() {
    memset(a, 0, sizeof(a));
  }

  void print() const { // for debug
    if (r < 0 || c < 0) return;
    for (int i = 1; i <= r; i++) {
      for (int j = 1; j <= c; j++) {
        printf("%lld ", a[i][j]);
      }
      puts("");
    }
  }

  Matrix operator * (const Matrix b) {
    Matrix res;
    if (c != b.r) {
      puts("error");
      exit(0);
    }
    res.r = r; res.c = b.c;
    //print(); b.print();
    for (int i = 1; i <= r; i++)
      for (int j = 1; j <= c; j++)
        for (int k = 1; k <= c; k++)
          res.a[i][j] = (res.a[i][j] + a[i][k] * b.a[k][j]) % mod;
    //res.print();
    return res;
  }
};

Matrix matrix_pow(Matrix m, ll n) {
  Matrix res;
  res.r = res.c = m.r;
  for (int i = 1; i <= res.r; i++) res.a[i][i] = 1;
  while (n) {
    if (n & 1)  res = res * m;
    m = m * m;
    n >>= 1;
  }
  return res;
}

ll n;
Matrix f, base, I;
void init() {
  f.r = 1; f.c = 2;
  f.a[1][1] = f.a[1][2] = 1;
  base.r = base.c = 2;
  base.a[1][1] = base.a[1][2] = base.a[2][1] = 1;
  base.a[2][2] = 0;
}


int main() {
  while (cin >> n) {
    init();
    //(f * matrix_pow(base, n - 2)).print();
    if (n > 2) cout << (f * matrix_pow(base, n - 2)).a[1][1] << endl;
    else cout << 1 << endl;
  }
  return 0;
}
posted @ 2021-10-06 15:00  _vv123  阅读(70)  评论(0编辑  收藏  举报