计蒜客模拟赛D2T1 蒜头君的兔子:矩阵快速幂

题目链接:https://nanti.jisuanke.com/t/16442

题意:

  有个人在第一年送了你一对1岁的兔子。这种兔子刚生下来的时候算0岁,当它在2~10岁的时候,每年都会生下一对兔子,并且它在10岁那年生完兔子后就会挂掉。现在让你算出第t年兔子的总数(不算那一年10岁的兔子)。

 

题解:

  我们用一个1*10的矩阵代表某一年的兔子数量,第k列上的数字n代表今年有n只k岁的兔子。

  那么初始矩阵是这样的:

  

  接下来考虑怎样构造特殊矩阵。

  有两个转移关系:

    第二年0岁的兔子数 = 第二年2~10岁的兔子数之和 = 今年1~9岁的兔子数之和

    第二年k (1<=k<=10) 岁的兔子数 = 今年k-1岁的兔子数

  也就是这样转移:

  

  b[0] = a[1] + a[2] + ... + a[9]

  b[1] = a[0]

  b[2] = a[1]

  ...

  b[9] = a[8]

  b[10] = a[9]

  那么特殊矩阵也就出来了:

  

  所以第t年的矩阵ans = 初始矩阵start * ( 特殊矩阵special ^ (t-1) )

  优化:由于在整个过程中根本没有用到每年10岁的兔子数量,所以可以省去初始矩阵的第10列,以及特殊矩阵的第10列&第10行。

 

AC Code:

#include <iostream>
#include <stdio.h>
#include <string.h>
#define MAX_L 15
#define MOD 1000000007

using namespace std;

struct Mat
{
    int n;
    int m;
    long long v[MAX_L][MAX_L];
    Mat()
    {
        memset(v,0,sizeof(v));
        n=0;
        m=0;
    }
};

int t;
long long sum=0;

Mat make_unit(int k)
{
    Mat mat;
    mat.n=k;
    mat.m=k;
    for(int i=0;i<k;i++)
    {
        mat.v[i][i]=1;
    }
    return mat;
}

Mat make_start()
{
    Mat mat;
    mat.n=1;
    mat.m=10;
    mat.v[0][1]=1;
    return mat;
}

Mat make_special()
{
    Mat mat;
    mat.n=10;
    mat.m=10;
    for(int i=1;i<=9;i++)
    {
        mat.v[i][0]=1;
        mat.v[i-1][i]=1;
    }
    return mat;
}

Mat mul_mat(const Mat &a,const Mat &b)
{
    Mat c;
    c.n=a.n;
    c.m=b.m;
    for(int i=0;i<a.n;i++)
    {
        for(int j=0;j<b.m;j++)
        {
            for(int k=0;k<a.m;k++)
            {
                c.v[i][j]+=(a.v[i][k]*b.v[k][j])%MOD;
                c.v[i][j]%=MOD;
            }
        }
    }
    return c;
}

Mat quick_pow_mat(Mat mat,long long k)
{
    Mat ans;
    ans=make_unit(mat.n);
    while(k)
    {
        if(k&1)
        {
            ans=mul_mat(ans,mat);
        }
        k>>=1;
        mat=mul_mat(mat,mat);
    }
    return ans;
}

void read()
{
    cin>>t;
}

void solve()
{
    Mat start=make_start();
    Mat special=make_special();
    Mat ans=mul_mat(start,quick_pow_mat(special,t-1));
    for(int i=0;i<=9;i++)
    {
        sum=(sum+ans.v[0][i])%MOD;
    }
}

void print()
{
    cout<<sum<<endl;
}

int main()
{
    read();
    solve();
    print();
}

 

posted @ 2017-07-31 02:37  Leohh  阅读(941)  评论(0编辑  收藏  举报