BZOJ 3231: [Sdoi2008]递归数列
Time Limit: 1 Sec Memory Limit: 256 MB
Submit: 916 Solved: 408
[Submit][Status][Discuss]
Description
一个由自然数组成的数列按下式定义:
对于i <= k:ai = bi
对于i > k: ai = c1ai-1 + c2ai-2 + … + ckai-k
其中bj和 cj (1<=j<=k)是给定的自然数。写一个程序,给定自然数m <= n, 计算am + am+1 + am+2 + … + an, 并输出它除以给定自然数p的余数的值。
Input
由四行组成。
第一行是一个自然数k。
第二行包含k个自然数b1, b2,…,bk。
第三行包含k个自然数c1, c2,…,ck。
第四行包含三个自然数m, n, p。
Output
仅包含一行:一个正整数,表示(am + am+1 + am+2 + … + an) mod p的值。
Sample Input
2
1 1
1 1
2 10 1000003
Sample Output
142
HINT
对于100%的测试数据:
1<= k<=15
1 <= m <= n <= 1018
解题思路
这道题可以看成一个类似斐波那契数列的矩阵乘法,初始矩阵是一行长度为k+1,前k个元素为b[i],最后一个为前缀和。而转移矩阵是一个(k+1)*(k+1)的矩阵,前k-1列中a[i+1][i]为1,第k列前k行为c[k]~c[1]这k个元素,最后一行为0。第k+1列前k行为c[k]~c[1]这k个元素,最后一行 为1,这样就可以转移。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
const int MAXN = 20;
typedef long long LL;
int k;
LL b[MAXN],c[MAXN],p,n,m;
LL sum[MAXN],ans1,ans2;
struct Mat{
LL a[MAXN][MAXN];
Mat(){memset(a,0,sizeof(a));}
Mat operator*(const Mat &h){
Mat ret;
for(register int i=1;i<=k+1;i++)
for(register int j=1;j<=k+1;j++)
for(register int o=1;o<=k+1;o++)
ret.a[i][j]=(ret.a[i][j]+a[i][o]*h.a[o][j])%p;
return ret;
}
}pre,ans,data;
inline Mat fast_pow(Mat A,LL b){
Mat ret;
for(register int i=1;i<=k+1;i++) ret.a[i][i]=1;
for(;b;b>>=1){
if(b&1) ret=ret*A;
A=A*A;
}
return ret;
}
int main(){
scanf("%d",&k);
for(register int i=1;i<=k;i++) scanf("%lld",&b[i]);
for(register int i=1;i<=k;i++) scanf("%lld",&c[i]);
scanf("%lld%lld%lld",&m,&n,&p);
for(register int i=1;i<=k;i++) pre.a[1][i]=b[i],sum[i]=(sum[i-1]+b[i])%p;
pre.a[1][k+1]=sum[k];
for(register int i=1;i<=k-1;i++) data.a[i+1][i]=1;
for(register int i=1;i<=k;i++) data.a[i][k]=data.a[i][k+1]=c[k-i+1];
data.a[k+1][k+1]=1;
// for(register int i=1;i<=k+1;i++) cout<<pre.a[1][i]<<" ";
// cout<<endl;
// cout<<"------------>"<<endl;
// for(register int i=1;i<=k+1;i++){
// for(register int j=1;j<=k+1;j++)
// cout<<data.a[i][j]<<" ";
// cout<<endl;
// }
if(m<=k) ans1=sum[m-1];else ans1=((pre*fast_pow(data,m-k-1)).a[1][k+1]);
if(n<=k) ans2=sum[n];else ans2=((pre*fast_pow(data,n-k)).a[1][k+1]);
// cout<<ans1<<" "<<ans2<<endl;
printf("%lld",((ans2-ans1)%p+p)%p);
return 0;
}