[Libre]Shlw loves matrix I

学习一个很玄学很酷炫很不可做的新技巧。
对于一个 $ k $ 项的齐次线性递推式在OI里面我们很容易就可以写出一个转移矩阵 , 对于矩阵快速幂那套理论我们已经身经百战了。
然而毒瘤出题人在这道题里要求你做 $ k = 1000 $ , 这种时候直接计算一个$ k ^ 3 $的矩乘明显是要T飞的。
所以我们才有这道题目里面的技巧ORZ。

现在我们要求一个常次线性递推:

\[a_{n} = \sum_{i = 1}^{k}{h_{i} * a_{n - i}} \]

首先先引入一个 \(Cayley-Hamilton\) 定理 , 这个定理给我们保证了对于一个 $ k * k $ 的转移矩阵\(M\)
拥有特征多项式 $ p(\lambda) = det(\lambda I - M) $ , 且存在 $ p(M) = O $ 。
然后对这个\(p(\lambda)\)拉普拉斯展开 :

\[p(x) = x ^ {k} - \sum_{i = 1}^{k}{h_{i} * x^{k - i}} \]

于是我们就有 :

\[p(M) = M ^ {k} - \sum_{i = 1}^{k}{h_{i} * M ^ {k - i}} = 0 \]

特别的 , 对于这个 $ p(M) $ 我们可以把它叫做化零多项式 , 利用它我们可以快速计算 $ G(M) $
比如说对于一个多项式除法 : $ G(M) = D(M)p(M) + r(M) => G(M) = D(M)O + r(M) => G(M) = r(M) $
所以我们要求 \(G(M)\) 只需要让它对这个\(p(x)\)取模之后再丢\(M\)进去算就行了。

这是我们要求的是 $ M ^ {n - k +1} $ , 现在我们不加证明的说 $ M ^ {n - k +1} $ 可以被 $ I , M ^ {1} ,M ^ {2} , ... , M ^{k - 1} $ 线性组合表示。
显然 $ M^{i} $ 可以被 $ M^{i - j} * M^{j} $ 的线性组合表示 , 于是我们把矩乘表现为了线性组合相乘 , 即两个多项式相乘的形式。
同时这个线性组合的乘法我们用快速幂总共只要做 $ log n $ 下就能求到 $ n - k +1 $ ,每次乘完之后都要模化零多项式求出 $ M ^ {i} $ 的线性组合。
总复杂度即 $ O(k^2 log n) $


#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<map>
#include<stack>
#include<queue>
#include<set>
#include<vector>
#include<cstdio>
#define lson k << 1
#define rson k << 1 | 1

using namespace std;
typedef long long ll;
const int N = 4005 , Mod = 1e9 + 7 , G = 3 , INF = 0x3f3f3f3f;
double Pi = acos(-1);

struct Virt{
    double x , y;
    Virt(double _x = 0.0 , double _y = 0.0):x(_x) , y(_y){}
};

Virt operator + (Virt x , Virt y){return Virt(x.x + y.x , x.y + y.y);}
Virt operator - (Virt x , Virt y){return Virt(x.x - y.x , x.y - y.y);}
Virt operator * (Virt x , Virt y){return Virt(x.x * y.x - x.y * y.y , x.x * y.y + x.y * y.x);}

int read(){
	int x = 0 , f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -1 ; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = x * 10 + ch - '0'; ch = getchar();}
	return x * f;
}

int n , k;
ll f[N] , a[N] , b[N] , h[N] , ans[N];

void mul_mod(ll *A , ll *B , ll *C){
	static ll temp[N];
	for(int i = 0 ; i <= k * 2 ; i++) temp[i] = 0;
	for(int i = 0 ; i <= k ; i++)
		for(int j = 0 ; j <= k ; j++)
			temp[i + j] = (temp[i + j] + A[i] * B[j] % Mod) % Mod;
	for(int i = k * 2 ; i >= k ; i--)
		for(int j = k - 1 ; ~j ; j--)
			temp[i - k + j] = (temp[i - k + j] + Mod - temp[i] * f[j] % Mod) % Mod;
	for(int i = 0 ; i < k ; i++) C[i] = temp[i];
}

int main(){
	n = read(); k = read();
	for(int i = 1 ; i <= k ; i++)
		a[i] = read() , a[i] = (a[i] + Mod) % Mod; 
	for(int i = 1 ; i <= k ; i++)
		h[i] = read() , h[i] = (h[i] + Mod) % Mod;
	for(int i = 1 ; i <= k ; i++) f[k - i] = Mod - a[i];
	f[k] = 1; ans[0] = b[1] = 1;
	for(int i = n - k + 1 ; i ; i >>= 1){
		if(i & 1) mul_mod(b , ans , ans);
		mul_mod(b , b , b);
	}
	ll sum = 0;
	for(int i = k + 1 ; i <= k * 2 ; i++)
		for(int j = 1 ; j <= k ; j++)
			h[i] = (h[i] + h[i - j] * a[j] % Mod + Mod) % Mod;
	for(int i = 0 ; i < k ; i++)
		sum = (sum + h[i + k] * ans[i] % Mod + Mod) % Mod;
	printf("%lld\n",sum);		
}

posted @ 2018-05-04 21:24  FranceDisco  阅读(276)  评论(0编辑  收藏  举报