Cayley-Hamilton定理与矩阵快速幂优化、常系数齐次线性递推优化
原文链接www.cnblogs.com/zhouzhendong/p/Cayley-Hamilton.html
Cayley-Hamilton定理与矩阵快速幂优化、常系数齐次线性递推优化
引入
在开始本文之前,我们先用一个例题作为引入。
- 给定一个 \(n \times n\) 的矩阵 \(M\) , 求 \(M ^ k\) 。
- \(n\leq 50, k\leq 10 ^ {50000}\) 。
注意到 \(n\) 十分小,但是 $ \log k$ 非常大。如果使用传统的矩阵快速幂,时间复杂度为 \(O(n ^ 3 \log k )\) ,难以接受。
但是运用 Cayley-Hamilton定理 来优化矩阵快速幂,可以做到 \(O(n ^ 4+n^2\log k)\) 甚至更优秀的复杂度\(^1\)。
- cly 说他看到有人说可以优化到 \(O(n^3 + n^2 \log k)\) 。但是不知道怎么优化。
Cayley-Hamilton定理
设 \(M\) 为一个 \(n\) 阶矩阵,定义矩阵 \(M\) 的特征多项式为
其中 \(x\) 可以属于一些域,包括但不限于复数域。
由于 \(|ME - M| = |0| = 0\) ,所以
矩阵快速幂
如果我们得到了 \(f(M)\) ,那么,将任意一个矩阵减去任意倍数的 \(f(M)\) 后值不变。
即:令 \(g(x) = x^k\),
考虑如何求解 \(g \bmod f\)。
类比对整数域下取模的快速幂做法,考虑对多项式 \(x\) 做快速幂,对 \(f\) 取模,直接实现的时间复杂度为 \(O(n ^ 2 \log k)\),如果采用 \(FFT\) 优化乘法,并利用多项式取模的做法实现取模,那么时间复杂度为 \(O(n \log n \log k)\) 。
接下来我们来讨论如何求 \(f(M)\) 。
将 \(n\) 个值代入 \(x\) 求行列式,再插值得到 \(f(M)\) ,时间复杂度 $O(n ^ 4) $ ,注意这里要求代入的值存在乘法逆元。
得到 \(g \bmod f\) 之后只需要将所有 \(M\) 的幂代入即可得到 $M ^ k $,这里时间复杂度也是 $O(n ^ 4) $ 。
综上所述,总时间复杂度 \(O(n ^ 4)\) 。
线性递推
回归本源,当我们要做矩阵快速幂的原因,往往是为了快速实现线性递推。由于线性递推问题存在特殊性,我们可以通过 Cayley-Hamilton定理 来得到更优秀的做法。
假设线性递推数列满足
我们将线性递推的矩阵写出来:
假设初始向量是列向量 \(B\) ,矩阵是 \(M\) ,那么
则我们要求的是
接下来我们来求 \(M\) 的特征多项式。
于是我们就直接得到了特征多项式的系数。
于是
假设结果为
因为 \(B M ^ i [a] = B[a + i]\),而这里 \(i < k\),我们要求的是 \(BM^i[0]\) ,所以我们只需要知道 \(B[0] \cdots B[k-1]\) 即可。类似地,只要我们预处理 B 数列的前 \(2k\) 项,就可以得到整个 \(BM^i\) 列向量了。
求单个值的时间复杂度为 \(O(k ^ 2\log n)\) 。
模板题 BZOJ4161
代码
#include <bits/stdc++.h>
#define clr(x) memset(x,0,sizeof x)
#define For(i,a,b) for (int i=a;i<=b;i++)
#define Fod(i,b,a) for (int i=b;i>=a;i--)
#define fi first
#define se second
#define pb(x) push_back(x)
#define mp(x,y) make_pair(x,y)
#define outval(x) printf(#x" = %d\n",x)
#define outtag(x) puts("---------------"#x"---------------")
#define outarr(a,L,R) printf(#a"[%d..%d] = ",L,R);\
For(_x,L,R)printf("%d ",a[_x]);puts("")
using namespace std;
typedef long long LL;
LL read(){
LL x=0,f=0;
char ch=getchar();
while (!isdigit(ch))
f|=ch=='-',ch=getchar();
while (isdigit(ch))
x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return f?-x:x;
}
const int N=2005*2,mod=1e9+7;
int n,k;
int a[N],b[N];
void Add(int &x,int y){
if ((x+=y)>=mod)
x-=mod;
}
void Del(int &x,int y){
if ((x-=y)<0)
x+=mod;
}
int c[N];
void Mul(int *x,int *y){
static int z[N];
clr(z);
For(i,0,k-1)
For(j,0,k-1)
Add(z[i+j],(LL)x[i]*y[j]%mod);
Fod(i,2*k-2,k){
if (!z[i])
continue;
For(j,1,k)
Add(z[i-j],(LL)a[j]*z[i]%mod);
}
For(i,0,k-1)
x[i]=z[i];
}
void GetPoly(){
static int x[N];
int y=n;
clr(x),clr(c),c[0]=x[1]=1;
for (;y;y>>=1,Mul(x,x))
if (y&1)
Mul(c,x);
}
int main(){
n=read(),k=read();
For(i,1,k)
a[i]=(read()+mod)%mod;
For(i,0,k-1)
b[i]=(read()+mod)%mod;
GetPoly();
int ans=0;
For(i,0,k-1)
Add(ans,(LL)b[i]*c[i]%mod);
cout<<ans<<endl;
return 0;
}