矩阵优化
知识基础:
矩阵乘法
矩阵快速幂
序列型动态规划
- 矩阵乘法
对于矩阵A和矩阵B
这两个矩阵的乘法当且仅当A的列数与B的行数相等时才能定义。
如A是m * n矩阵, B是n * p矩阵,他们的乘积C是一个m * p矩阵C = ( cij )
其中 cij = Σnk = 1 aik * bkj
记作 C = AB
注:矩阵乘法满足结合律(即 (AB)C = A(BC),这是快速幂的基础) 、
左结合律(即(A + B)C = AC + BC)、右结合律(即C(A + B) = CA + CB)
但不满足交换律!!!
- 矩阵快速幂
1 void matrix_pow(long long y){ 2 int i, j, k; 3 memset(ans, 0, sizeof(ans)); 4 for(i = 1; i <= n; i++) 5 ans[i][i] = 1;//单位矩阵建立 6 while(y > 0){ 7 if(y & 1){ 8 memset(temp, 0, sizeof(temp)); 9 for(i = 1; i <= n; i++) 10 for(j = 1; j <= n; j++) 11 for(k = 1; k <= n; k++) 12 temp[i][j] = (temp[i][j] + ans[i][k] * mat[k][j]) % P; 13 for(i = 1; i <= n; i++) 14 for(j = 1; j <= n; j++) 15 ans[i][j] = temp[i][j]; 16 //ans = (ans * mat) % P; 17 } 18 memset(temp, 0, sizeof(temp)); 19 int i, j, k; 20 for(i = 1; i <= n; i++) 21 for(j = 1; j <= n; j++) 22 for(k = 1; k <= n; k++) 23 temp[i][j] = (temp[i][j] + mat[i][k] * mat[k][j]) % P; 24 for(i = 1; i <= n; i++) 25 for(j = 1; j <= n; j++) 26 mat[i][j] = temp[i][j]; 27 // mat = (mat * mat) % P; 28 y >>= 1; 29 } 30 }
模板题 : 洛谷 P3390 【模板】矩阵快速幂
然后接可以开始矩阵优化了
拿斐波那契数列来举例吧
众所周知 斐波那契数列满足f(n) = f(n - 1) + f(n - 2)
边界为 f(1) = f(2) = 1
如果要求f(1e15) 那普通线性推导分分钟T爆
考虑优化:
由于得到 f(n) 需要它前两个值
若已知矩阵 { f(n - 1), f(n - 2) } 就可以导出f(n)
当然 也可以通过它建立矩阵{ f(n), f(n - 1) } 这样推导下去
推导方式就是让它乘上一个矩阵mat
f(n) = 1 * f(n - 1) + 1 * f(n - 2);
所以mat的第一列是{1, 1}
f(n - 1) = 1 * f(n - 1) + 0 * f(n - 2);
所以矩阵mat第二列是{1, 0}
这样我们得到了mat
1 1
1 0
推导n次 就相当于
{ f(2), f(1) } * (mat)n - 2 = { f(n), f(n - 1) }
模板题 :洛谷 P1939 【模板】矩阵加速(数列)
例题:洛谷 P1306 斐波那契公约数
竟然把结论手推出来了…【还是孤陋寡闻啊
gcd( f(a), f(b) ) = f( gcd(a, b) )
然后矩阵优化推斐波那契就行了
例题 : 洛谷 P1357 花园
本题使用状压
深搜预处理合理状态
建立矩阵进行状态转移
再传统矩阵加速
1 void update(int cnt, int st){ 2 okk[st] = 1; 3 int pre = st >> 1; 4 mat[pre][st] = 1; 5 if(cnt == l && !(st & 1)) return ; 6 mat[pre | (1 << (m - 1))][st] = 1; 7 } 8 void dfs(int cur, int cnt, int st){ 9 if(cur > m){ 10 update(cnt, st); return ; 11 } 12 dfs(cur + 1, cnt, st); 13 if(cnt < l) dfs(cur + 1, cnt + 1, st | (1 << (cur - 1))); 14 }
1 int main(){ 2 scanf("%lld%lld%lld", &n, &m, &l); 3 dfs(1, 0, 0); 4 lim = (1 << m); 5 qsort(n); 6 long long tot = 0; 7 for(int i = 0; i < lim; i++) 8 if(okk[i]) tot = (tot + ans[i][i]) % P; 9 printf("%lld", tot); 10 return 0; 11 }
记得取模 n会爆int