BZOJ 3231: [Sdoi2008]递归数列

Time Limit: 1 Sec Memory Limit: 256 MB
Submit: 916 Solved: 408
[Submit][Status][Discuss]
Description

一个由自然数组成的数列按下式定义:
对于i <= k:ai = bi
对于i > k: ai = c1ai-1 + c2ai-2 + … + ckai-k
其中bj和 cj (1<=j<=k)是给定的自然数。写一个程序,给定自然数m <= n, 计算am + am+1 + am+2 + … + an, 并输出它除以给定自然数p的余数的值。
Input

由四行组成。
第一行是一个自然数k。
第二行包含k个自然数b1, b2,…,bk。
第三行包含k个自然数c1, c2,…,ck。
第四行包含三个自然数m, n, p。
Output

仅包含一行:一个正整数,表示(am + am+1 + am+2 + … + an) mod p的值。
Sample Input

2

1 1

1 1

2 10 1000003

Sample Output

142
HINT

对于100%的测试数据:

1<= k<=15

1 <= m <= n <= 1018

解题思路

这道题可以看成一个类似斐波那契数列的矩阵乘法,初始矩阵是一行长度为k+1,前k个元素为b[i],最后一个为前缀和。而转移矩阵是一个(k+1)*(k+1)的矩阵,前k-1列中a[i+1][i]为1,第k列前k行为c[k]~c[1]这k个元素,最后一行为0。第k+1列前k行为c[k]~c[1]这k个元素,最后一行 为1,这样就可以转移。

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>

using namespace std;
const int MAXN = 20;
typedef long long LL;

int k;
LL b[MAXN],c[MAXN],p,n,m;
LL sum[MAXN],ans1,ans2;

struct Mat{
    LL a[MAXN][MAXN];
    Mat(){memset(a,0,sizeof(a));}
    Mat operator*(const Mat &h){
        Mat ret;
        for(register int i=1;i<=k+1;i++)
            for(register int j=1;j<=k+1;j++)
                for(register int o=1;o<=k+1;o++)
                    ret.a[i][j]=(ret.a[i][j]+a[i][o]*h.a[o][j])%p;
        return ret;
    }
}pre,ans,data;

inline Mat fast_pow(Mat A,LL b){
    Mat ret;
    for(register int i=1;i<=k+1;i++) ret.a[i][i]=1;
    for(;b;b>>=1){
        if(b&1) ret=ret*A;
        A=A*A;
    }
    return ret;
}

int main(){
    scanf("%d",&k);
    for(register int i=1;i<=k;i++) scanf("%lld",&b[i]);
    for(register int i=1;i<=k;i++) scanf("%lld",&c[i]);
    scanf("%lld%lld%lld",&m,&n,&p);
    for(register int i=1;i<=k;i++) pre.a[1][i]=b[i],sum[i]=(sum[i-1]+b[i])%p;
    pre.a[1][k+1]=sum[k];
    for(register int i=1;i<=k-1;i++) data.a[i+1][i]=1;
    for(register int i=1;i<=k;i++) data.a[i][k]=data.a[i][k+1]=c[k-i+1];
    data.a[k+1][k+1]=1;
//  for(register int i=1;i<=k+1;i++) cout<<pre.a[1][i]<<" ";
//  cout<<endl;
//  cout<<"------------>"<<endl;
//  for(register int i=1;i<=k+1;i++){
//      for(register int j=1;j<=k+1;j++)
//          cout<<data.a[i][j]<<" ";
//      cout<<endl;
//  }
    if(m<=k) ans1=sum[m-1];else ans1=((pre*fast_pow(data,m-k-1)).a[1][k+1]);
    if(n<=k) ans2=sum[n];else ans2=((pre*fast_pow(data,n-k)).a[1][k+1]);
//  cout<<ans1<<" "<<ans2<<endl;
    printf("%lld",((ans2-ans1)%p+p)%p);
    return 0;
}
posted @ 2018-07-28 08:32  Monster_Qi  阅读(130)  评论(0编辑  收藏  举报