hdu4914 Linear recursive sequence

用矩阵求解线性递推式通项

用FFT优化多项式乘法（矩阵求幂先被转化为多项式乘法，再用FFT加速）

首先把递推式求解转化为矩阵求幂；再利用 Cayley–Hamilton 定理：转移矩阵 A 的特征多项式 χ(λ) = λ^q − a·λ^(q−p) − b 满足 χ(A) = 0，因此 A^n 只取决于 λ^n mod χ(λ)，矩阵求幂就转化为“多项式相乘再对 χ 取模”，

最后用快速傅里叶变换（FFT）完成其中的多项式乘法，FFT 采用迭代而非递归实现（参见《算法导论》第 30 章）。

 

 

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <algorithm>
  4 #include <map>
  5 #include <string>
  6 #include <vector>
  7 #include <set>
  8 #include <cmath>
  9 #include <ctime>
 10 #pragma comment(linker, "/STACK:102400000,102400000")
 11 using namespace std;
 12 #define lson (u << 1)
 13 #define rson (u << 1 | 1)
 14 typedef long long ll;
 15 const double eps = 1e-6;
 16 const double pi = acos(-1.0);
 17 const int maxn = 4e4 + 10;
 18 const int maxm = 1050;
 19 const int mod = 119;
 20 const int inf = 0x3f3f3f3f;
 21 
 22 int n, a, b, p, q;
 23 int size;
 24 int f[maxn], g[maxn];
 25 
 26 struct Complex{
 27     double ii, ij;//ii::real, ij::image
 28     Complex(double ii = 0, double ij = 0) : ii(ii), ij(ij) {}
 29 //    Complex clear() { this->ii = this->ij = 0; }
 30     Complex operator + (const Complex &rhs) const{
 31         return Complex(ii + rhs.ii, ij + rhs.ij);
 32     }
 33     Complex operator - (const Complex &rhs) const{
 34         return Complex(ii - rhs.ii, ij - rhs.ij);
 35     }
 36     Complex operator * (const Complex &rhs) const{
 37         return Complex(ii * rhs.ii - ij * rhs.ij, ii * rhs.ij + ij * rhs.ii);
 38     }
 39 };
 40 
 41 Complex a1[maxn], a2[maxn];
 42 
 43 void fft(Complex *src, int len, int rev){
 44     //len is power of 2
 45     //rev == 1::dft rev == -1::idft
 46     for(int i = 1, j = 0; i < len; i++){
 47         for(int k = len >> 1; k > (j ^= k); k >>= 1) ;
 48         if(i < j) swap(src[i], src[j]);
 49     }
 50     for(int i = 2; i <= len; i <<= 1){
 51         Complex wi(cos(2 * pi * rev / i), sin(2 * pi * rev / i));
 52         //(wi)^i = 1
 53         for(int j = 0; j < len; j += i){
 54             //using iteration insetad of recursion
 55             Complex w(1.0, 0.0);
 56             //w = (wi)^0
 57             for(int k = j; k < j + i / 2; k++){
 58                 Complex tem = w * src[k + i / 2];
 59                 src[k + i / 2] = src[k] - tem;
 60                 src[k] = src[k] + tem;
 61                 w = w * wi;
 62             }
 63         }
 64     }
 65     if(rev == -1){
 66         for(int i = 0; i < len; i++) src[i].ii = (src[i].ii / len + eps);
 67     }
 68 }
 69 
 70 void multi(int *src1, int *src2, int len){
 71     for(int i = 0; i < len; i++){
 72         a1[i].ii = a1[i].ij = a2[i].ii = a2[i].ij = 0;
 73         if(i < q){
 74             a1[i].ii = (double)src1[i];
 75             a2[i].ii = (double)src2[i];
 76         }
 77     }
 78     fft(a1, len, 1), fft(a2, len, 1);
 79     for(int i = 0; i < len; i++) a1[i] = a1[i] * a2[i];
 80     fft(a1, len, -1);
 81     for(int i = 0; i < len; i++) g[i] = (int)((ll)(a1[i].ii + eps) % mod);
 82     for(int i = 2 * q - 2; i >= q; i--){
 83         //this is because for the fisrt row in matrix A,
 84         //which satisfies ths (f(n + q),...,f(n))T = A((f(n + q - 1),...,f(n - 1))T)
 85         //only two elements are nonzero integers
 86         g[i - q] = (g[i - q] + g[i] * b) % mod;
 87         g[i - p] = (g[i - p] + g[i] * a) % mod;
 88     }
 89     memcpy(src1, g, sizeof(int) * q);
 90 }
 91 
 92 int tmp[maxn], ans[maxn];
 93 
 94 int main(){
 95     //freopen("in.txt", "r", stdin);
 96     while(~scanf("%d%d%d%d%d", &n, &a, &b, &p, &q)){
 97         a %= mod, b %= mod;
 98         f[0] = 1;
 99         for(int i = 1; i < q; i++){
100             f[i] = i < p ? a + b : a * f[i - p] + b;
101             f[i] %= mod;
102         }
103         if(n < q){
104             printf("%d\n", f[n]);
105             continue;
106         }
107         size = 1;
108         while(size <= (q - 1) * 2) size <<= 1;
109         memset(tmp, 0, sizeof tmp);
110         memset(ans, 0, sizeof ans);
111         ans[0] = tmp[1] = 1;
112         while(n){
113             if(n & 1) multi(ans, tmp, size);
114             multi(tmp, tmp, size);
115             n >>= 1;
116         }
117         int res = 0;
118         for(int i = 0; i < q; i++){
119             res = (res + ans[i] * f[i]) % mod;
120         }
121         printf("%d\n", res);
122     }
123     return 0;
124 }
View Code

 

posted @ 2015-10-21 15:52  astoninfer  阅读(408)  评论(0编辑  收藏  举报